diff --git a/ntex-codec/CHANGES.md b/ntex-codec/CHANGES.md index 2514c79b..29100ec9 100644 --- a/ntex-codec/CHANGES.md +++ b/ntex-codec/CHANGES.md @@ -1,8 +1,14 @@ # Changes +## [0.6.0] - 2021-12-xx + +* Removed Framed type + +* Removed tokio dependency + ## [0.5.1] - 2021-09-08 -* Fix tight loop in Framed::close() method. +* Fix tight loop in Framed::close() method ## [0.5.0] - 2021-06-27 diff --git a/ntex-codec/Cargo.toml b/ntex-codec/Cargo.toml index 84d70d91..13e835a1 100644 --- a/ntex-codec/Cargo.toml +++ b/ntex-codec/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-codec" -version = "0.5.1" +version = "0.6.0" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -16,12 +16,4 @@ name = "ntex_codec" path = "src/lib.rs" [dependencies] -bitflags = "1.3" ntex-bytes = "0.1" -ntex-util = "0.1" -log = "0.4" -tokio = { version = "1", default-features = false } - -[dev-dependencies] -ntex = "0.4.13" -futures = "0.3.13" diff --git a/ntex-codec/src/framed.rs b/ntex-codec/src/framed.rs deleted file mode 100644 index fb45aabe..00000000 --- a/ntex-codec/src/framed.rs +++ /dev/null @@ -1,691 +0,0 @@ -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{fmt, io}; - -use ntex_bytes::{Buf, BytesMut}; -use ntex_util::{future::Either, ready, Sink, Stream}; - -use crate::{AsyncRead, AsyncWrite, Decoder, Encoder}; - -const LW: usize = 1024; -const HW: usize = 8 * 1024; - -bitflags::bitflags! { - struct Flags: u8 { - const EOF = 0b0001; - const READABLE = 0b0010; - const DISCONNECTED = 0b0100; - const SHUTDOWN = 0b1000; - } -} - -/// A unified interface to an underlying I/O object, using -/// the `Encoder` and `Decoder` traits to encode and decode frames. -/// `Framed` is heavily optimized for streaming io. -pub struct Framed { - io: T, - codec: U, - flags: Flags, - read_buf: BytesMut, - write_buf: BytesMut, - err: Option, -} - -impl Framed -where - T: AsyncRead + AsyncWrite, - U: Decoder + Encoder, -{ - #[inline] - /// Provides an interface for reading and writing to - /// `Io` object, using `Decode` and `Encode` traits of codec. - /// - /// Raw I/O objects work with byte sequences, but higher-level code usually - /// wants to batch these into meaningful chunks, called "frames". This - /// method layers framing on top of an I/O object, by using the `Codec` - /// traits to handle encoding and decoding of messages frames. Note that - /// the incoming and outgoing frame types may be distinct. - pub fn new(io: T, codec: U) -> Framed { - Framed { - io, - codec, - err: None, - flags: Flags::empty(), - read_buf: BytesMut::with_capacity(HW), - write_buf: BytesMut::with_capacity(HW), - } - } -} - -impl Framed { - #[inline] - /// Construct `Framed` object `parts`. - pub fn from_parts(parts: FramedParts) -> Framed { - Framed { - io: parts.io, - codec: parts.codec, - flags: parts.flags, - write_buf: parts.write_buf, - read_buf: parts.read_buf, - err: parts.err, - } - } - - #[inline] - /// Returns a reference to the underlying codec. - pub fn get_codec(&self) -> &U { - &self.codec - } - - #[inline] - /// Returns a mutable reference to the underlying codec. - pub fn get_codec_mut(&mut self) -> &mut U { - &mut self.codec - } - - #[inline] - /// Returns a reference to the underlying I/O stream wrapped by `Framed`. - /// - /// Note that care should be taken to not tamper with the underlying stream - /// of data coming in as it may corrupt the stream of frames otherwise - /// being worked with. - pub fn get_ref(&self) -> &T { - &self.io - } - - #[inline] - /// Returns a mutable reference to the underlying I/O stream wrapped by - /// `Framed`. - /// - /// Note that care should be taken to not tamper with the underlying stream - /// of data coming in as it may corrupt the stream of frames otherwise - /// being worked with. - pub fn get_mut(&mut self) -> &mut T { - &mut self.io - } - - #[inline] - /// Get read buffer. - pub fn read_buf(&mut self) -> &mut BytesMut { - &mut self.read_buf - } - - #[inline] - /// Get write buffer. - pub fn write_buf(&mut self) -> &mut BytesMut { - &mut self.write_buf - } - - #[inline] - /// Check if write buffer is empty. - pub fn is_write_buf_empty(&self) -> bool { - self.write_buf.is_empty() - } - - #[inline] - /// Check if write buffer is full. - pub fn is_write_buf_full(&self) -> bool { - self.write_buf.len() >= HW - } - - #[inline] - /// Check if framed object is closed - pub fn is_closed(&self) -> bool { - self.flags.contains(Flags::DISCONNECTED) - } - - #[inline] - /// Consume the `Frame`, returning `Frame` with different codec. - pub fn into_framed(self, codec: U2) -> Framed { - Framed { - codec, - io: self.io, - flags: self.flags, - read_buf: self.read_buf, - write_buf: self.write_buf, - err: self.err, - } - } - - #[inline] - /// Consume the `Frame`, returning `Frame` with different io. - pub fn map_io(self, f: F) -> Framed - where - F: Fn(T) -> T2, - { - Framed { - io: f(self.io), - codec: self.codec, - flags: self.flags, - read_buf: self.read_buf, - write_buf: self.write_buf, - err: self.err, - } - } - - #[inline] - /// Consume the `Frame`, returning `Frame` with different codec. - pub fn map_codec(self, f: F) -> Framed - where - F: Fn(U) -> U2, - { - Framed { - io: self.io, - codec: f(self.codec), - flags: self.flags, - read_buf: self.read_buf, - write_buf: self.write_buf, - err: self.err, - } - } - - #[inline] - /// Consumes the `Frame`, returning its underlying I/O stream, the buffer - /// with unprocessed data, and the codec. - /// - /// Note that care should be taken to not tamper with the underlying stream - /// of data coming in as it may corrupt the stream of frames otherwise - /// being worked with. - pub fn into_parts(self) -> FramedParts { - FramedParts { - io: self.io, - codec: self.codec, - flags: self.flags, - read_buf: self.read_buf, - write_buf: self.write_buf, - err: self.err, - } - } -} - -impl Framed -where - T: AsyncWrite + Unpin, - U: Encoder, -{ - #[inline] - /// Serialize item and Write to the inner buffer - pub fn write( - &mut self, - item: ::Item, - ) -> Result<(), ::Error> { - let remaining = self.write_buf.capacity() - self.write_buf.len(); - if remaining < LW { - self.write_buf.reserve(HW - remaining); - } - - self.codec.encode(item, &mut self.write_buf)?; - Ok(()) - } - - #[inline] - /// Check if framed is able to write more data. - /// - /// `Framed` object considers ready if there is free space in write buffer. - pub fn is_write_ready(&self) -> bool { - self.write_buf.len() < HW - } - - /// Flush write buffer to underlying I/O stream. - pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll> { - log::trace!("flushing framed transport"); - - let len = self.write_buf.len(); - if len != 0 { - let mut written = 0; - while written < len { - match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) { - Poll::Pending => break, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!( - "Disconnected during flush, written {}", - written - ); - self.flags.insert(Flags::DISCONNECTED); - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))); - } else { - written += n - } - } - Poll::Ready(Err(e)) => { - log::trace!("Error during flush: {}", e); - self.flags.insert(Flags::DISCONNECTED); - return Poll::Ready(Err(e)); - } - } - } - log::trace!("flushed {} bytes", written); - - // remove written data - if written == len { - self.write_buf.clear() - } else { - self.write_buf.advance(written); - } - } - - // flush - ready!(Pin::new(&mut self.io).poll_flush(cx))?; - - if self.write_buf.is_empty() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } -} - -impl Framed -where - T: AsyncRead + AsyncWrite + Unpin, -{ - #[inline] - /// Flush write buffer and shutdown underlying I/O stream. - /// - /// Close method shutdown write side of a io object and - /// then reads until disconnect or error, high level code must use - /// timeout for close operation. - pub fn close(&mut self, cx: &mut Context<'_>) -> Poll> { - if !self.flags.contains(Flags::DISCONNECTED) { - // flush write buffer - ready!(Pin::new(&mut self.io).poll_flush(cx))?; - - if !self.flags.contains(Flags::SHUTDOWN) { - // shutdown WRITE side - ready!(Pin::new(&mut self.io).poll_shutdown(cx)).map_err(|e| { - self.flags.insert(Flags::DISCONNECTED); - e - })?; - self.flags.insert(Flags::SHUTDOWN); - } - - // read until 0 or err - let mut buf = [0u8; 512]; - loop { - let mut read_buf = tokio::io::ReadBuf::new(&mut buf); - match ready!(Pin::new(&mut self.io).poll_read(cx, &mut read_buf)) { - Err(_) | Ok(_) if read_buf.filled().is_empty() => { - break; - } - _ => (), - } - } - self.flags.insert(Flags::DISCONNECTED); - } - log::trace!("framed transport flushed and closed"); - Poll::Ready(Ok(())) - } -} - -pub type ItemType = - Result<::Item, Either<::Error, io::Error>>; - -impl Framed -where - T: AsyncRead + Unpin, - U: Decoder, -{ - /// Try to read underlying I/O stream and decode item. - pub fn next_item(&mut self, cx: &mut Context<'_>) -> Poll>> { - let mut done_read = false; - - loop { - // Repeatedly call `decode` or `decode_eof` as long as it is - // "readable". Readable is defined as not having returned `None`. If - // the upstream has returned EOF, and the decoder is no longer - // readable, it can be assumed that the decoder will never become - // readable again, at which point the stream is terminated. - - if self.flags.contains(Flags::READABLE) { - if self.flags.contains(Flags::EOF) { - return match self.codec.decode_eof(&mut self.read_buf) { - Ok(Some(frame)) => Poll::Ready(Some(Ok(frame))), - Ok(None) => { - if let Some(err) = self.err.take() { - Poll::Ready(Some(Err(Either::Right(err)))) - } else if !self.read_buf.is_empty() { - Poll::Ready(Some(Err(Either::Right(io::Error::new( - io::ErrorKind::Other, - "bytes remaining on stream", - ))))) - } else { - Poll::Ready(None) - } - } - Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))), - }; - } - - log::trace!("attempting to decode a frame"); - - match self.codec.decode(&mut self.read_buf) { - Ok(Some(frame)) => { - log::trace!("frame decoded from buffer"); - return Poll::Ready(Some(Ok(frame))); - } - Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))), - _ => (), // Need more data - } - - self.flags.remove(Flags::READABLE); - if done_read { - return Poll::Pending; - } - } - - debug_assert!(!self.flags.contains(Flags::EOF)); - - // read all data from socket - let mut updated = false; - loop { - // Otherwise, try to read more data and try again. Make sure we've got room - let remaining = self.read_buf.capacity() - self.read_buf.len(); - if remaining < LW { - self.read_buf.reserve(HW - remaining) - } - match crate::poll_read_buf( - Pin::new(&mut self.io), - cx, - &mut self.read_buf, - ) { - Poll::Pending => { - if updated { - done_read = true; - self.flags.insert(Flags::READABLE); - break; - } else { - return Poll::Pending; - } - } - Poll::Ready(Ok(n)) => { - if n == 0 { - self.flags.insert(Flags::EOF | Flags::READABLE); - if updated { - done_read = true; - } - break; - } else { - updated = true; - } - } - Poll::Ready(Err(e)) => { - if updated { - done_read = true; - self.err = Some(e); - self.flags.insert(Flags::EOF | Flags::READABLE); - break; - } else { - return Poll::Ready(Some(Err(Either::Right(e)))); - } - } - } - } - } - } -} - -impl Stream for Framed -where - T: AsyncRead + Unpin, - U: Decoder + Unpin, -{ - type Item = Result>; - - #[inline] - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.next_item(cx) - } -} - -impl Sink for Framed -where - T: AsyncRead + AsyncWrite + Unpin, - U: Encoder + Unpin, -{ - type Error = Either; - - #[inline] - fn poll_ready( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - if self.is_write_ready() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - - #[inline] - fn start_send( - mut self: Pin<&mut Self>, - item: ::Item, - ) -> Result<(), Self::Error> { - self.write(item).map_err(Either::Left) - } - - #[inline] - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.flush(cx).map_err(Either::Right) - } - - #[inline] - fn poll_close( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.close(cx).map_err(Either::Right) - } -} - -impl fmt::Debug for Framed -where - T: fmt::Debug, - U: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Framed") - .field("io", &self.io) - .field("codec", &self.codec) - .finish() - } -} - -/// `FramedParts` contains an export of the data of a Framed transport. -/// It can be used to construct a new `Framed` with a different codec. -/// It contains all current buffers and the inner transport. -#[derive(Debug)] -pub struct FramedParts { - /// The inner transport used to read bytes to and write bytes to - pub io: T, - - /// The codec - pub codec: U, - - /// The buffer with read but unprocessed data. - pub read_buf: BytesMut, - - /// A buffer with unprocessed data which are not written yet. - pub write_buf: BytesMut, - - flags: Flags, - err: Option, -} - -impl FramedParts { - /// Create a new, default, `FramedParts` - pub fn new(io: T, codec: U) -> FramedParts { - FramedParts { - io, - codec, - err: None, - flags: Flags::empty(), - read_buf: BytesMut::new(), - write_buf: BytesMut::new(), - } - } - - /// Create a new `FramedParts` with read buffer - pub fn with_read_buf(io: T, codec: U, read_buf: BytesMut) -> FramedParts { - FramedParts { - io, - codec, - read_buf, - err: None, - flags: Flags::empty(), - write_buf: BytesMut::new(), - } - } -} - -#[cfg(test)] -mod tests { - use futures::{future::lazy, Sink}; - use ntex::testing::Io; - use ntex_bytes::Bytes; - - use super::*; - use crate::BytesCodec; - - #[ntex::test] - async fn test_basics() { - let (_, server) = Io::create(); - let mut server = Framed::new(server, BytesCodec); - server.get_codec_mut(); - server.get_ref(); - server.get_mut(); - - let parts = server.into_parts(); - let server = Framed::from_parts(FramedParts::new(parts.io, parts.codec)); - assert!(format!("{:?}", server).contains("Framed")); - } - - #[ntex::test] - async fn test_sink() { - let (client, server) = Io::create(); - client.remote_buffer_cap(1024); - let mut server = Framed::new(server, BytesCodec); - - assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx)) - .await - .is_ready()); - - let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"); - Pin::new(&mut server).start_send(data).unwrap(); - assert_eq!(client.read_any(), b"".as_ref()); - assert_eq!(server.read_buf(), b"".as_ref()); - assert_eq!(server.write_buf(), b"GET /test HTTP/1.1\r\n\r\n".as_ref()); - - assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx)) - .await - .is_ready()); - assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref()); - - assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) - .await - .is_pending()); - client.close().await; - assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) - .await - .is_ready()); - assert!(client.is_closed()); - } - - #[ntex::test] - async fn test_write_pending() { - let (client, server) = Io::create(); - let mut server = Framed::new(server, BytesCodec); - - assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx)) - .await - .is_ready()); - let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"); - Pin::new(&mut server).start_send(data).unwrap(); - - client.remote_buffer_cap(3); - assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx)) - .await - .is_pending()); - assert_eq!(client.read_any(), b"GET".as_ref()); - - client.remote_buffer_cap(1024); - assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx)) - .await - .is_ready()); - assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref()); - - assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) - .await - .is_pending()); - client.close().await; - assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) - .await - .is_ready()); - assert!(client.is_closed()); - assert!(server.is_closed()); - } - - #[ntex::test] - async fn test_read_pending() { - let (client, server) = Io::create(); - let mut server = Framed::new(server, BytesCodec); - - client.read_pending(); - assert!(lazy(|cx| Pin::new(&mut server).next_item(cx)) - .await - .is_pending()); - - client.write(b"GET /test HTTP/1.1\r\n\r\n"); - client.close().await; - - let item = lazy(|cx| Pin::new(&mut server).next_item(cx)) - .await - .map(|i| i.unwrap().unwrap().freeze()); - assert_eq!( - item, - Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n")) - ); - let item = lazy(|cx| Pin::new(&mut server).next_item(cx)) - .await - .map(|i| i.is_none()); - assert_eq!(item, Poll::Ready(true)); - } - - #[ntex::test] - async fn test_read_error() { - let (client, server) = Io::create(); - let mut server = Framed::new(server, BytesCodec); - - client.read_pending(); - assert!(lazy(|cx| Pin::new(&mut server).next_item(cx)) - .await - .is_pending()); - - client.write(b"GET /test HTTP/1.1\r\n\r\n"); - client.read_error(io::Error::new(io::ErrorKind::Other, "error")); - - let item = lazy(|cx| Pin::new(&mut server).next_item(cx)) - .await - .map(|i| i.unwrap().unwrap().freeze()); - assert_eq!( - item, - Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n")) - ); - assert_eq!( - lazy(|cx| Pin::new(&mut server).next_item(cx)) - .await - .map(|i| i.unwrap().is_err()), - Poll::Ready(true) - ); - } -} diff --git a/ntex-codec/src/lib.rs b/ntex-codec/src/lib.rs index 779bb7e2..b4a2feb5 100644 --- a/ntex-codec/src/lib.rs +++ b/ntex-codec/src/lib.rs @@ -1,55 +1,10 @@ -//! Utilities for encoding and decoding frames. -//! -//! Contains adapters to go from streams of bytes, [`AsyncRead`] and -//! [`AsyncWrite`], to framed streams implementing `Sink` and `Stream`. -//! Framed streams are also known as `transports`. -//! -//! [`AsyncRead`]: # -//! [`AsyncWrite`]: # #![deny(rust_2018_idioms, warnings)] -use std::{io, mem::MaybeUninit, pin::Pin, task::Context, task::Poll}; +//! Utilities for encoding and decoding frames. mod bcodec; mod decoder; mod encoder; -mod framed; pub use self::bcodec::BytesCodec; pub use self::decoder::Decoder; pub use self::encoder::Encoder; -pub use self::framed::{Framed, FramedParts}; - -pub use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -use ntex_bytes::{BufMut, BytesMut}; - -pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut BytesMut, -) -> Poll> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - - let n = { - let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [MaybeUninit]) }; - let mut buf = ReadBuf::uninit(dst); - let ptr = buf.filled().as_ptr(); - if io.poll_read(cx, &mut buf)?.is_pending() { - return Poll::Pending; - } - - // Ensure the pointer does not change from under us - assert_eq!(ptr, buf.filled().as_ptr()); - buf.filled().len() - }; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by `ReadBuf::filled`. - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) -} diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs index 66fd9d74..37996c00 100644 --- a/ntex-io/src/dispatcher.rs +++ b/ntex-io/src/dispatcher.rs @@ -1,6 +1,6 @@ //! Framed transport dispatcher use std::{ - cell::Cell, future::Future, io, pin::Pin, rc::Rc, task::Context, task::Poll, time, + cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time, }; use ntex_bytes::Pool; @@ -70,7 +70,6 @@ enum DispatcherError { KeepAlive, Encoder(U), Service(S), - Io(io::Error), } enum PollService { @@ -157,7 +156,7 @@ where /// /// By default disconnect timeout is set to 1 seconds. pub fn disconnect_timeout(self, val: Seconds) -> Self { - self.inner.state.set_disconnect_timeout(val); + self.inner.state.set_disconnect_timeout(val.into()); self } } @@ -176,12 +175,7 @@ where Ok(Some(val)) => match write.encode(val, &self.codec) { Ok(true) => (), Ok(false) => write.enable_backpressure(None), - Err(Either::Left(err)) => { - self.error.set(Some(DispatcherError::Encoder(err))) - } - Err(Either::Right(err)) => { - self.error.set(Some(DispatcherError::Io(err))) - } + Err(err) => self.error.set(Some(DispatcherError::Encoder(err))), }, Err(err) => self.error.set(Some(DispatcherError::Service(err))), Ok(None) => return, @@ -407,12 +401,7 @@ where Ok(Some(item)) => match write.encode(item, &self.shared.codec) { Ok(true) => (), Ok(false) => write.enable_backpressure(None), - Err(Either::Left(err)) => { - self.shared.error.set(Some(DispatcherError::Encoder(err))) - } - Err(Either::Right(err)) => { - self.shared.error.set(Some(DispatcherError::Io(err))) - } + Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))), }, Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))), Ok(None) => (), @@ -443,9 +432,6 @@ where DispatcherError::Encoder(err) => { PollService::Item(DispatchItem::EncoderError(err)) } - DispatcherError::Io(err) => { - PollService::Item(DispatchItem::Disconnect(Some(err))) - } DispatcherError::Service(err) => { self.error.set(Some(err)); PollService::ServiceError diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index 961dcfc8..ee2cf5c8 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -46,13 +46,7 @@ impl ReadFilter for DefaultFilter { #[inline] fn read_closed(&self, err: Option) { - if err.is_some() { - self.0.error.set(err); - } - self.0.write_task.wake(); - self.0.dispatch_task.wake(); - self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP); - self.0.notify_disconnect(); + self.0.set_error(err); } #[inline] @@ -109,13 +103,9 @@ impl WriteFilter for DefaultFilter { #[inline] fn write_closed(&self, err: Option) { - if err.is_some() { - self.0.error.set(err); - } - self.0.read_task.wake(); + self.0.set_error(err); + self.0.insert_flags(Flags::IO_CLOSED); self.0.dispatch_task.wake(); - self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP); - self.0.notify_disconnect(); } #[inline] diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index b25a04af..b9e01a61 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -1,4 +1,4 @@ -use std::{fmt, future::Future, io, task::Context, task::Poll}; +use std::{any::Any, any::TypeId, fmt, future::Future, io, task::Context, task::Poll}; pub mod testing; @@ -14,12 +14,12 @@ mod tokio_impl; use ntex_bytes::BytesMut; use ntex_codec::{Decoder, Encoder}; -use ntex_util::time::Millis; +use ntex_util::{channel::oneshot::Receiver, future::Either, time::Millis}; pub use self::dispatcher::Dispatcher; pub use self::filter::DefaultFilter; -pub use self::state::{Io, IoRef, ReadRef, WriteRef}; -pub use self::tasks::{ReadState, WriteState}; +pub use self::state::{Io, IoRef, OnDisconnect, ReadRef, WriteRef}; +pub use self::tasks::{ReadContext, WriteContext}; pub use self::time::Timer; pub use self::utils::{filter_factory, from_iostream, into_boxed, into_io}; @@ -55,8 +55,15 @@ pub trait WriteFilter { fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>; } -pub trait Filter: ReadFilter + WriteFilter { +pub trait Filter: ReadFilter + WriteFilter + 'static { fn shutdown(&self, st: &IoRef) -> Poll>; + + fn query( + &self, + id: TypeId, + ) -> Either>, Receiver>>> { + Either::Left(None) + } } pub trait FilterFactory: Sized { @@ -69,7 +76,7 @@ pub trait FilterFactory: Sized { } pub trait IoStream { - fn start(self, _: ReadState, _: WriteState); + fn start(self, _: ReadContext, _: WriteContext); } /// Framed transport item diff --git a/ntex-io/src/state.rs b/ntex-io/src/state.rs index 8e049789..f73f6081 100644 --- a/ntex-io/src/state.rs +++ b/ntex-io/src/state.rs @@ -1,14 +1,14 @@ use std::cell::{Cell, RefCell}; use std::task::{Context, Poll}; -use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc}; +use std::{fmt, future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc}; use ntex_bytes::{BytesMut, PoolId, PoolRef}; use ntex_codec::{Decoder, Encoder}; -use ntex_util::time::{Millis, Seconds}; +use ntex_util::time::Millis; use ntex_util::{future::poll_fn, future::Either, task::LocalWaker}; use super::filter::{DefaultFilter, NullFilter}; -use super::tasks::{ReadState, WriteState}; +use super::tasks::{ReadContext, WriteContext}; use super::{Filter, FilterFactory, IoStream}; bitflags::bitflags! { @@ -21,13 +21,15 @@ bitflags::bitflags! { const IO_FILTERS_TO = 0b0000_0000_0000_0100; /// shutdown io tasks const IO_SHUTDOWN = 0b0000_0000_0000_1000; + /// io object is closed + const IO_CLOSED = 0b0000_0000_0001_0000; /// pause io read - const RD_PAUSED = 0b0000_0000_0000_1000; + const RD_PAUSED = 0b0000_0000_0010_0000; /// new data is available - const RD_READY = 0b0000_0000_0001_0000; + const RD_READY = 0b0000_0000_0100_0000; /// read buffer is full - const RD_BUF_FULL = 0b0000_0000_0010_0000; + const RD_BUF_FULL = 0b0000_0000_1000_0000; /// wait write completion const WR_WAIT = 0b0000_0001_0000_0000; @@ -103,8 +105,22 @@ impl IoStateInner { } #[inline] - fn is_io_err(&self) -> bool { - self.flags.get().contains(Flags::IO_ERR) + fn is_io_open(&self) -> bool { + !self.flags.get().intersects( + Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_SHUTDOWN | Flags::IO_CLOSED, + ) + } + + #[inline] + pub(super) fn set_error(&self, err: Option) { + if err.is_some() { + self.error.set(err); + } + self.read_task.wake(); + self.write_task.wake(); + self.dispatch_task.wake(); + self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP); + self.notify_disconnect(); } #[inline] @@ -195,7 +211,7 @@ impl Io { let io_ref = IoRef(inner); // start io tasks - io.start(ReadState(io_ref.clone()), WriteState(io_ref.clone())); + io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone())); Io(io_ref, FilterItem::Ptr(Box::into_raw(filter))) } @@ -218,8 +234,8 @@ impl Io { #[inline] /// Set io disconnect timeout in secs - pub fn set_disconnect_timeout(&self, timeout: Seconds) { - self.0 .0.disconnect_timeout.set(timeout.into()); + pub fn set_disconnect_timeout(&self, timeout: Millis) { + self.0 .0.disconnect_timeout.set(timeout); } } @@ -242,31 +258,6 @@ impl Io { pub fn register_dispatcher(&self, cx: &mut Context<'_>) { self.0 .0.dispatch_task.register(cx.waker()); } - - #[inline] - /// Mark dispatcher as stopped - pub fn dispatcher_stopped(&self) { - self.0 .0.insert_flags(Flags::DSP_STOP); - } - - #[inline] - /// Gracefully shutdown read and write io tasks - pub fn init_shutdown(&self, cx: &mut Context<'_>) { - let flags = self.0 .0.flags.get(); - - if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) { - log::trace!("initiate io shutdown {:?}", flags); - self.0 .0.insert_flags(Flags::IO_FILTERS); - if let Err(err) = self.0 .0.shutdown_filters(&self.0) { - self.0 .0.error.set(Some(err)); - self.0 .0.insert_flags(Flags::IO_ERR); - } - - self.0 .0.read_task.wake(); - self.0 .0.write_task.wake(); - self.0 .0.dispatch_task.register(cx.waker()); - } - } } impl IoRef { @@ -284,9 +275,9 @@ impl IoRef { } #[inline] - /// Check if io error occured in read or write task - pub fn is_io_err(&self) -> bool { - self.0.is_io_err() + /// Check if io is still active + pub fn is_io_open(&self) -> bool { + self.0.is_io_open() } #[inline] @@ -304,10 +295,13 @@ impl IoRef { #[inline] /// Check if io stream is closed pub fn is_closed(&self) -> bool { - self.0 - .flags - .get() - .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP) + self.0.flags.get().intersects( + Flags::IO_ERR + | Flags::IO_SHUTDOWN + | Flags::IO_CLOSED + | Flags::IO_FILTERS + | Flags::DSP_STOP, + ) } #[inline] @@ -316,6 +310,12 @@ impl IoRef { self.0.error.take() } + #[inline] + /// Mark dispatcher as stopped + pub fn stop_dispatcher(&self) { + self.0.insert_flags(Flags::DSP_STOP); + } + #[inline] /// Reset keep-alive error pub fn reset_keepalive(&self) { @@ -360,9 +360,15 @@ impl IoRef { pub fn on_disconnect(&self) -> OnDisconnect { OnDisconnect::new(self.0.clone(), self.0.flags.get().contains(Flags::IO_ERR)) } + + #[inline] + /// Query specific data + pub fn query(&self) -> Option { + todo!() + } } -impl Io { +impl IoRef { #[inline] /// Read incoming io stream and decode codec item. pub async fn next( @@ -375,18 +381,18 @@ impl Io { let read = self.read(); loop { - let mut buf = self.0 .0.read_buf.take(); + let mut buf = self.0.read_buf.take(); let item = if let Some(ref mut buf) = buf { codec.decode(buf) } else { Ok(None) }; - self.0 .0.read_buf.set(buf); + self.0.read_buf.set(buf); return match item { Ok(Some(el)) => Ok(Some(el)), Ok(None) => { - self.0 .0.remove_flags(Flags::RD_READY); + self.0.remove_flags(Flags::RD_READY); if poll_fn(|cx| read.poll_ready(cx)) .await .map_err(Either::Right)? @@ -411,53 +417,53 @@ impl Io { where U: Encoder, { - let filter = self.0 .0.filter.get(); + let filter = self.0.filter.get(); let mut buf = filter .get_write_buf() - .unwrap_or_else(|| self.0 .0.pool.get().get_write_buf()); + .unwrap_or_else(|| self.0.pool.get().get_write_buf()); let is_write_sleep = buf.is_empty(); codec.encode(item, &mut buf).map_err(Either::Left)?; filter.release_write_buf(buf).map_err(Either::Right)?; - self.0 .0.insert_flags(Flags::WR_WAIT); + self.0.insert_flags(Flags::WR_WAIT); if is_write_sleep { - self.0 .0.write_task.wake(); + self.0.write_task.wake(); } - poll_fn(|cx| self.write().poll_flush(cx)) + poll_fn(|cx| self.write().poll_flush(cx, true)) .await .map_err(Either::Right)?; Ok(()) } #[inline] - /// Shuts down connection - pub async fn shutdown(&self) -> Result<(), io::Error> { - if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { - Ok(()) - } else { - poll_fn(|cx| { - let flags = self.flags(); - if !flags.contains(Flags::IO_FILTERS) { - self.init_shutdown(cx); - } + /// Shut down connection + pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll> { + let flags = self.flags(); - if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { - if let Some(err) = self.0 .0.error.take() { - Poll::Ready(Err(err)) - } else { - Poll::Ready(Ok(())) - } - } else { - self.0 .0.insert_flags(Flags::IO_FILTERS); - self.0 .0.dispatch_task.register(cx.waker()); - Poll::Pending - } - }) - .await + if flags.intersects(Flags::IO_ERR | Flags::IO_CLOSED) { + Poll::Ready(Ok(())) + } else { + if !flags.contains(Flags::IO_FILTERS) { + self.init_shutdown(cx); + } + self.0.insert_flags(Flags::IO_FILTERS); + + if let Some(err) = self.0.error.take() { + Poll::Ready(Err(err)) + } else { + self.0.dispatch_task.register(cx.waker()); + Poll::Pending + } } } + #[inline] + /// Shut down connection + pub async fn shutdown(&self) -> Result<(), io::Error> { + poll_fn(|cx| self.poll_shutdown(cx)).await + } + #[inline] #[allow(clippy::type_complexity)] pub fn poll_next( @@ -468,38 +474,48 @@ impl Io { where U: Decoder, { - if self - .read() - .poll_ready(cx) - .map_err(Either::Right)? - .is_ready() - { - let mut buf = self.0 .0.read_buf.take(); - let item = if let Some(ref mut buf) = buf { - codec.decode(buf) - } else { - Ok(None) - }; - self.0 .0.read_buf.set(buf); + let read = self.read(); - match item { - Ok(Some(el)) => Poll::Ready(Ok(Some(el))), - Ok(None) => { - if let Poll::Ready(res) = - self.read().poll_ready(cx).map_err(Either::Right)? - { - if res.is_none() { - return Poll::Ready(Ok(None)); - } + match read.decode(codec) { + Ok(Some(el)) => Poll::Ready(Ok(Some(el))), + Ok(None) => { + if let Poll::Ready(res) = read.poll_ready(cx).map_err(Either::Right)? { + if res.is_none() { + return Poll::Ready(Ok(None)); } - Poll::Pending } - Err(err) => Poll::Ready(Err(Either::Left(err))), + Poll::Pending } - } else { - Poll::Pending + Err(err) => Poll::Ready(Err(Either::Left(err))), } } + + #[inline] + /// Gracefully shutdown read and write io tasks + pub(super) fn init_shutdown(&self, cx: &mut Context<'_>) { + let flags = self.0.flags.get(); + + if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) { + log::trace!("initiate io shutdown {:?}", flags); + self.0.insert_flags(Flags::IO_FILTERS); + if let Err(err) = self.0.shutdown_filters(self) { + self.0.error.set(Some(err)); + self.0.insert_flags(Flags::IO_ERR); + } + + self.0.read_task.wake(); + self.0.write_task.wake(); + self.0.dispatch_task.register(cx.waker()); + } + } +} + +impl fmt::Debug for IoRef { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IoRef") + .field("open", &!self.is_closed()) + .finish() + } } impl Io { @@ -576,7 +592,10 @@ impl Io { impl Drop for Io { fn drop(&mut self) { - log::trace!("stopping io stream"); + log::trace!( + "io is dropped, force stopping io streams {:?}", + self.0.flags() + ); if let FilterItem::Ptr(p) = self.1 { if p.is_null() { return; @@ -635,7 +654,7 @@ impl<'a> WriteRef<'a> { /// /// Write task must be waken up separately. pub fn enable_backpressure(&self, cx: Option<&mut Context<'_>>) { - log::trace!("enable write back-pressure"); + log::trace!("enable write back-pressure {:?}", cx.is_some()); self.0.insert_flags(Flags::WR_BACKPRESSURE); if let Some(cx) = cx { self.0.dispatch_task.register(cx.waker()); @@ -669,7 +688,7 @@ impl<'a> WriteRef<'a> { &self, item: U::Item, codec: &U, - ) -> Result::Error, io::Error>> + ) -> Result::Error> where U: Encoder, { @@ -690,28 +709,44 @@ impl<'a> WriteRef<'a> { } // encode item and wake write task - let result = codec - .encode(item, &mut buf) - .map(|_| { - if is_write_sleep { - self.0.write_task.wake(); - } - buf.len() < hw - }) - .map_err(Either::Left); - filter.release_write_buf(buf).map_err(Either::Right)?; - Ok(result?) + let result = codec.encode(item, &mut buf).map(|_| { + if is_write_sleep { + self.0.write_task.wake(); + } + buf.len() < hw + }); + if let Err(err) = filter.release_write_buf(buf) { + self.0.set_error(Some(err)); + } + result } else { Ok(true) } } #[inline] - /// Wake write task and instruct to write all data. + /// Wake write task and instruct to write data. /// - /// When write task is done wake dispatcher. - pub fn poll_flush(&self, cx: &mut Context<'_>) -> Poll> { - self.0.insert_flags(Flags::WR_WAIT); + /// If full is true then wake up dispatcher when all data is flushed + /// otherwise wake up when size of write buffer is lower than + /// buffer max size. + pub fn poll_flush( + &self, + cx: &mut Context<'_>, + full: bool, + ) -> Poll> { + // check io error + if !self.0.is_io_open() { + return Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| { + io::Error::new(io::ErrorKind::Other, "disconnected") + }))); + } + + if full { + self.0.insert_flags(Flags::WR_WAIT); + } else { + self.0.insert_flags(Flags::WR_BACKPRESSURE); + } if let Some(buf) = self.0.write_buf.take() { if !buf.is_empty() { @@ -722,14 +757,16 @@ impl<'a> WriteRef<'a> { } } - if self.0.is_io_err() { - Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| { - io::Error::new(io::ErrorKind::Other, "disconnected") - }))) - } else { - self.0.dispatch_task.register(cx.waker()); - Poll::Ready(Ok(())) - } + // self.0.dispatch_task.register(cx.waker()); + Poll::Ready(Ok(())) + } + + #[inline] + /// Wake write task and instruct to write data. + /// + /// This is async version of .poll_flush() method. + pub async fn flush(&self, full: bool) -> Result<(), io::Error> { + poll_fn(|cx| self.poll_flush(cx, full)).await } } @@ -834,7 +871,7 @@ impl<'a> ReadRef<'a> { let mut flags = self.0.flags.get(); let ready = flags.contains(Flags::RD_READY); - if self.0.is_io_err() { + if !self.0.is_io_open() { if let Some(err) = self.0.error.take() { Poll::Ready(Err(err)) } else { @@ -843,7 +880,6 @@ impl<'a> ReadRef<'a> { } else if ready { Poll::Ready(Ok(Some(()))) } else { - flags.remove(Flags::RD_READY); if flags.contains(Flags::RD_BUF_FULL) { log::trace!("read back-pressure is enabled, wake io task"); flags.remove(Flags::RD_BUF_FULL); @@ -939,7 +975,6 @@ mod tests { #[ntex::test] async fn utils() { - env_logger::init(); let (client, server) = IoTest::create(); client.remote_buffer_cap(1024); client.write(TEXT); @@ -1041,7 +1076,7 @@ mod tests { in_bytes: Rc>, out_bytes: Rc>, } - impl Filter for Counter { + impl Filter for Counter { fn shutdown(&self, _: &IoRef) -> Poll> { Poll::Ready(Ok(())) } diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 4a956a3f..0b07f90d 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -4,9 +4,9 @@ use ntex_bytes::{BytesMut, PoolRef}; use super::{state::Flags, IoRef, WriteReadiness}; -pub struct ReadState(pub(super) IoRef); +pub struct ReadContext(pub(super) IoRef); -impl ReadState { +impl ReadContext { #[inline] pub fn memory_pool(&self) -> PoolRef { self.0 .0.pool.get() @@ -60,9 +60,9 @@ impl ReadState { } } -pub struct WriteState(pub(super) IoRef); +pub struct WriteContext(pub(super) IoRef); -impl WriteState { +impl WriteContext { #[inline] pub fn memory_pool(&self) -> PoolRef { self.0 .0.pool.get() diff --git a/ntex-io/src/testing.rs b/ntex-io/src/testing.rs index 173fe516..1277d0d3 100644 --- a/ntex-io/src/testing.rs +++ b/ntex-io/src/testing.rs @@ -1,11 +1,13 @@ use std::cell::{Cell, RefCell}; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll, Waker}; -use std::{cmp, fmt, io, mem}; +use std::{cmp, fmt, future::Future, io, mem, pin::Pin, rc::Rc}; -use ntex_bytes::{BufMut, BytesMut}; +use ntex_bytes::{Buf, BufMut, BytesMut}; use ntex_util::future::poll_fn; -use ntex_util::time::{sleep, Millis}; +use ntex_util::time::{sleep, Millis, Sleep}; + +use crate::{IoStream, ReadContext, WriteContext, WriteReadiness}; #[derive(Default)] struct AtomicWaker(Arc>>>); @@ -441,138 +443,181 @@ mod tokio { } } -#[cfg(not(feature = "tokio"))] -mod non_tokio { - impl IoStream for IoTest { - fn start(self, read: ReadState, write: WriteState) { - let io = Rc::new(self); +impl IoStream for IoTest { + fn start(self, read: ReadContext, write: WriteContext) { + let io = Rc::new(self); - ntex_util::spawn(ReadTask { - io: io.clone(), - state: read, - }); - ntex_util::spawn(WriteTask { - io, - state: write, - st: IoWriteState::Processing, - }); - } + ntex_util::spawn(ReadTask { + io: io.clone(), + state: read, + }); + ntex_util::spawn(WriteTask { + io, + state: write, + st: IoWriteState::Processing(None), + }); } +} - /// Read io task - struct ReadTask { - io: Rc, - state: ReadState, - } +/// Read io task +struct ReadTask { + io: Rc, + state: ReadContext, +} - impl Future for ReadTask { - type Output = (); +impl Future for ReadTask { + type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_ref(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_ref(); - match this.state.poll_ready(cx) { - Poll::Ready(Err(())) => { - log::trace!("read task is instructed to terminate"); - Poll::Ready(()) - } - Poll::Ready(Ok(())) => { - let io = &this.io; - let pool = this.state.memory_pool(); - let mut buf = self.state.get_read_buf(); - let (hw, lw) = pool.read_params().unpack(); + match this.state.poll_ready(cx) { + Poll::Ready(Err(())) => { + log::trace!("read task is instructed to terminate"); + Poll::Ready(()) + } + Poll::Ready(Ok(())) => { + let io = &this.io; + let pool = this.state.memory_pool(); + let mut buf = self.state.get_read_buf(); + let (hw, lw) = pool.read_params().unpack(); - // read data from socket - let mut new_bytes = 0; - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - match io.poll_read_buf(cx, &mut buf) { - Poll::Pending => { - log::trace!("no more data in io stream"); - break; - } - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("io stream is disconnected"); - this.state.release_read_buf(buf, new_bytes); - this.state.close(None); - return Poll::Ready(()); - } else { - new_bytes += n; - if buf.len() > hw { - break; - } - } - } - Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); - } - } + // read data from socket + let mut new_bytes = 0; + loop { + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); } - this.state.release_read_buf(buf, new_bytes); - Poll::Pending - } - Poll::Pending => Poll::Pending, - } - } - } - - #[derive(Debug)] - enum IoWriteState { - Processing, - Shutdown(Option, Shutdown), - } - - #[derive(Debug)] - enum Shutdown { - None, - Flushed, - Stopping, - } - - /// Write io task - struct WriteTask { - st: IoWriteState, - io: Rc, - state: WriteState, - } - - impl Future for WriteTask { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().get_mut(); - - match this.st { - IoWriteState::Processing => { - match this.state.poll_ready(cx) { - Poll::Ready(Ok(())) => { - // flush framed instance - match flush_io(&this.io, &this.state, cx) { - Poll::Pending | Poll::Ready(true) => Poll::Pending, - Poll::Ready(false) => Poll::Ready(()), + match io.poll_read_buf(cx, &mut buf) { + Poll::Pending => { + log::trace!("no more data in io stream"); + break; + } + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("io stream is disconnected"); + let _ = this.state.release_read_buf(buf, new_bytes); + this.state.close(None); + return Poll::Ready(()); + } else { + new_bytes += n; + if buf.len() > hw { + break; + } } } - Poll::Ready(Err(WriteReadiness::Shutdown)) => { - log::trace!("write task is instructed to shutdown"); - - this.st = IoWriteState::Shutdown( - this.state.disconnect_timeout().map(sleep), - Shutdown::None, - ); - self.poll(cx) + Poll::Ready(Err(err)) => { + log::trace!("read task failed on io {:?}", err); + let _ = this.state.release_read_buf(buf, new_bytes); + this.state.close(Some(err)); + return Poll::Ready(()); } - Poll::Ready(Err(WriteReadiness::Terminate)) => { - log::trace!("write task is instructed to terminate"); + } + } + + let _ = this.state.release_read_buf(buf, new_bytes); + Poll::Pending + } + Poll::Pending => Poll::Pending, + } + } +} + +#[derive(Debug)] +enum IoWriteState { + Processing(Option), + Shutdown(Option, Shutdown), +} + +#[derive(Debug)] +enum Shutdown { + None, + Flushed, + Stopping, +} + +/// Write io task +struct WriteTask { + st: IoWriteState, + io: Rc, + state: WriteContext, +} + +impl Future for WriteTask { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().get_mut(); + + match this.st { + IoWriteState::Processing(ref mut delay) => { + match this.state.poll_ready(cx) { + Poll::Ready(Ok(())) => { + // flush framed instance + match flush_io(&this.io, &this.state, cx) { + Poll::Pending | Poll::Ready(true) => Poll::Pending, + Poll::Ready(false) => Poll::Ready(()), + } + } + Poll::Ready(Err(WriteReadiness::Timeout(time))) => { + if delay.is_none() { + *delay = Some(sleep(time)); + } + self.poll(cx) + } + Poll::Ready(Err(WriteReadiness::Shutdown(time))) => { + log::trace!("write task is instructed to shutdown"); + + let timeout = if let Some(delay) = delay.take() { + delay + } else { + sleep(time) + }; + + this.st = IoWriteState::Shutdown(Some(timeout), Shutdown::None); + self.poll(cx) + } + Poll::Ready(Err(WriteReadiness::Terminate)) => { + log::trace!("write task is instructed to terminate"); + // shutdown WRITE side + this.io + .local + .lock() + .unwrap() + .borrow_mut() + .flags + .insert(Flags::CLOSED); + this.state.close(None); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + } + } + IoWriteState::Shutdown(ref mut delay, ref mut st) => { + // close WRITE side and wait for disconnect on read side. + // use disconnect timeout, otherwise it could hang forever. + loop { + match st { + Shutdown::None => { + // flush write buffer + match flush_io(&this.io, &this.state, cx) { + Poll::Ready(true) => { + *st = Shutdown::Flushed; + continue; + } + Poll::Ready(false) => { + log::trace!( + "write task is closed with err during flush" + ); + return Poll::Ready(()); + } + _ => (), + } + } + Shutdown::Flushed => { // shutdown WRITE side this.io .local @@ -581,143 +626,102 @@ mod non_tokio { .borrow_mut() .flags .insert(Flags::CLOSED); - this.state.close(None); - Poll::Ready(()) + *st = Shutdown::Stopping; + continue; } - Poll::Pending => Poll::Pending, - } - } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - match st { - Shutdown::None => { - // flush write buffer - match flush_io(&this.io, &this.state, cx) { - Poll::Ready(true) => { - *st = Shutdown::Flushed; - continue; - } - Poll::Ready(false) => { - log::trace!( - "write task is closed with err during flush" - ); + Shutdown::Stopping => { + // read until 0 or err + let io = &this.io; + loop { + let mut buf = BytesMut::new(); + match io.poll_read_buf(cx, &mut buf) { + Poll::Ready(Err(e)) => { + this.state.close(Some(e)); + log::trace!("write task is stopped"); return Poll::Ready(()); } + Poll::Ready(Ok(n)) if n == 0 => { + this.state.close(None); + log::trace!("write task is stopped"); + return Poll::Ready(()); + } + Poll::Pending => break, _ => (), } } - Shutdown::Flushed => { - // shutdown WRITE side - this.io - .local - .lock() - .unwrap() - .borrow_mut() - .flags - .insert(Flags::CLOSED); - *st = Shutdown::Stopping; - continue; - } - Shutdown::Stopping => { - // read until 0 or err - let io = &this.io; - loop { - let mut buf = BytesMut::new(); - match io.poll_read_buf(cx, &mut buf) { - Poll::Ready(Err(e)) => { - this.state.close(Some(e)); - log::trace!("write task is stopped"); - return Poll::Ready(()); - } - Poll::Ready(Ok(n)) if n == 0 => { - this.state.close(None); - log::trace!("write task is stopped"); - return Poll::Ready(()); - } - Poll::Pending => break, - _ => (), - } - } - } } - - // disconnect timeout - if let Some(ref delay) = delay { - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - } - log::trace!("write task is stopped after delay"); - this.state.close(None); - return Poll::Ready(()); } + + // disconnect timeout + if let Some(ref delay) = delay { + if delay.poll_elapsed(cx).is_pending() { + return Poll::Pending; + } + } + log::trace!("write task is stopped after delay"); + this.state.close(None); + return Poll::Ready(()); } } } } +} - /// Flush write buffer to underlying I/O stream. - pub(super) fn flush_io( - io: &IoTest, - state: &WriteState, - cx: &mut Context<'_>, - ) -> Poll { - let mut buf = if let Some(buf) = state.get_write_buf() { - buf - } else { - return Poll::Ready(true); - }; - let len = buf.len(); - let pool = state.memory_pool(); +/// Flush write buffer to underlying I/O stream. +pub(super) fn flush_io( + io: &IoTest, + state: &WriteContext, + cx: &mut Context<'_>, +) -> Poll { + let mut buf = if let Some(buf) = state.get_write_buf() { + buf + } else { + return Poll::Ready(true); + }; + let len = buf.len(); - if len != 0 { - log::trace!("flushing framed transport: {}", len); + if len != 0 { + log::trace!("flushing framed transport: {}", len); - let mut written = 0; - while written < len { - match io.poll_write_buf(cx, &buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!( - "disconnected during flush, written {}", - written - ); - pool.release_write_buf(buf); - state.close(Some(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))); - return Poll::Ready(false); - } else { - written += n - } - } - Poll::Pending => break, - Poll::Ready(Err(e)) => { - log::trace!("error during flush: {}", e); - pool.release_write_buf(buf); - state.close(Some(e)); + let mut written = 0; + while written < len { + match io.poll_write_buf(cx, &buf[written..]) { + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("disconnected during flush, written {}", written); + let _ = state.release_write_buf(buf); + state.close(Some(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))); return Poll::Ready(false); + } else { + written += n } } + Poll::Pending => break, + Poll::Ready(Err(e)) => { + log::trace!("error during flush: {}", e); + let _ = state.release_write_buf(buf); + state.close(Some(e)); + return Poll::Ready(false); + } } - log::trace!("flushed {} bytes", written); - - // remove written data - if written == len { - buf.clear(); - state.release_write_buf(buf); - Poll::Ready(true) - } else { - buf.advance(written); - state.release_write_buf(buf); - Poll::Pending - } - } else { - Poll::Ready(true) } + log::trace!("flushed {} bytes", written); + + // remove written data + if written == len { + buf.clear(); + let _ = state.release_write_buf(buf); + Poll::Ready(true) + } else { + buf.advance(written); + let _ = state.release_write_buf(buf); + Poll::Pending + } + } else { + Poll::Ready(true) } } diff --git a/ntex-io/src/tokio_impl.rs b/ntex-io/src/tokio_impl.rs index 81d829ac..560b90a1 100644 --- a/ntex-io/src/tokio_impl.rs +++ b/ntex-io/src/tokio_impl.rs @@ -3,15 +3,12 @@ use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc}; use ntex_bytes::{Buf, BufMut}; use ntex_util::time::{sleep, Sleep}; -use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf}; +use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf, net::TcpStream}; -use super::{IoStream, ReadState, WriteReadiness, WriteState}; +use super::{IoStream, ReadContext, WriteContext, WriteReadiness}; -impl IoStream for T -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn start(self, read: ReadState, write: WriteState) { +impl IoStream for TcpStream { + fn start(self, read: ReadContext, write: WriteContext) { let io = Rc::new(RefCell::new(self)); ntex_util::spawn(ReadTask::new(io.clone(), read)); @@ -19,26 +16,29 @@ where } } -/// Read io task -struct ReadTask { - io: Rc>, - state: ReadState, +#[cfg(unix)] +impl IoStream for tok_io::net::UnixStream { + fn start(self, _read: ReadContext, _write: WriteContext) { + let _io = Rc::new(RefCell::new(self)); + + todo!() + } } -impl ReadTask -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +/// Read io task +struct ReadTask { + io: Rc>, + state: ReadContext, +} + +impl ReadTask { /// Create new read io task - fn new(io: Rc>, state: ReadState) -> Self { + fn new(io: Rc>, state: ReadContext) -> Self { Self { io, state } } } -impl Future for ReadTask -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Future for ReadTask { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -119,18 +119,15 @@ enum Shutdown { } /// Write io task -struct WriteTask { +struct WriteTask { st: IoWriteState, - io: Rc>, - state: WriteState, + io: Rc>, + state: WriteContext, } -impl WriteTask -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl WriteTask { /// Create new write io task - fn new(io: Rc>, state: WriteState) -> Self { + fn new(io: Rc>, state: WriteContext) -> Self { Self { io, state, @@ -139,10 +136,7 @@ where } } -impl Future for WriteTask -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Future for WriteTask { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -272,7 +266,7 @@ where /// Flush write buffer to underlying I/O stream. pub(super) fn flush_io( io: &mut T, - state: &WriteState, + state: &WriteContext, cx: &mut Context<'_>, ) -> Poll { let mut buf = if let Some(buf) = state.get_write_buf() { @@ -284,12 +278,14 @@ pub(super) fn flush_io( let pool = state.memory_pool(); if len != 0 { - // log::trace!("flushing framed transport: {:?}", buf); + //log::trace!("flushing framed transport: {:?}", buf); let mut written = 0; while written < len { match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { - Poll::Pending => break, + Poll::Pending => { + break; + } Poll::Ready(Ok(n)) => { if n == 0 { log::trace!("Disconnected during flush, written {}", written); @@ -311,7 +307,7 @@ pub(super) fn flush_io( } } } - // log::trace!("flushed {} bytes", written); + //log::trace!("flushed {} bytes", written); // remove written data let result = if written == len { diff --git a/ntex-openssl/Cargo.toml b/ntex-openssl/Cargo.toml index 2c11be8c..31b2132f 100644 --- a/ntex-openssl/Cargo.toml +++ b/ntex-openssl/Cargo.toml @@ -22,6 +22,6 @@ ntex-util = "0.1.2" openssl = "0.10.32" [dev-dependencies] -ntex = { version = "0.4.14", features = ["openssl"] } +ntex = { version = "0.5.0", features = ["openssl"] } futures = "0.3" env_logger = "0.9" diff --git a/ntex-openssl/src/lib.rs b/ntex-openssl/src/lib.rs index 0becd770..2ca27f89 100644 --- a/ntex-openssl/src/lib.rs +++ b/ntex-openssl/src/lib.rs @@ -10,6 +10,13 @@ use ntex_io::{ use ntex_util::{future::poll_fn, time, time::Millis}; use openssl::ssl::{self, SslStream}; +/// Selected alpn protocol +pub enum AlpnHttpProtocol { + Http1, + Http2, +} + +/// An implementation of SSL streams pub struct SslFilter { inner: RefCell>>, } @@ -191,7 +198,7 @@ impl SslAcceptor { /// Set handshake timeout. /// /// Default is set to 5 seconds. - pub fn timeout>(mut self, timeout: U) -> Self { + pub fn timeout>(&mut self, timeout: U) -> &mut Self { self.timeout = timeout.into(); self } @@ -209,7 +216,7 @@ impl Clone for SslAcceptor { impl FilterFactory for SslAcceptor { type Filter = SslFilter; - type Error = io::Error; + type Error = Box; type Future = Pin, Self::Error>>>>; fn create(self, st: Io) -> Self::Future { @@ -225,8 +232,7 @@ impl FilterFactory for SslAcceptor { read_buf: None, write_buf: None, }; - let ssl_stream = - ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?; + let ssl_stream = ssl::SslStream::new(ssl, inner)?; Ok(SslFilter { inner: RefCell::new(ssl_stream), @@ -234,9 +240,9 @@ impl FilterFactory for SslAcceptor { })?; poll_fn(|cx| { - let _ = st.write().poll_flush(cx)?; + let _ = st.write().poll_flush(cx, true)?; handle_result(st.filter().inner.borrow_mut().accept(), &st, cx) - .map_err(map_to_ioerr) + .map_err(Into::>::into) }) .await?; @@ -244,7 +250,7 @@ impl FilterFactory for SslAcceptor { }) .await .map_err(|_| { - io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout") + io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into() }) .and_then(|item| item) }) @@ -265,7 +271,7 @@ impl SslConnector { impl FilterFactory for SslConnector { type Filter = SslFilter; - type Error = io::Error; + type Error = Box; type Future = Pin, Self::Error>>>>; fn create(self, st: Io) -> Self::Future { @@ -277,8 +283,7 @@ impl FilterFactory for SslConnector { read_buf: None, write_buf: None, }; - let ssl_stream = - ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?; + let ssl_stream = ssl::SslStream::new(ssl, inner)?; Ok(SslFilter { inner: RefCell::new(ssl_stream), @@ -286,9 +291,9 @@ impl FilterFactory for SslConnector { })?; poll_fn(|cx| { - let _ = st.write().poll_flush(cx)?; + let _ = st.write().poll_flush(cx, true)?; handle_result(st.filter().inner.borrow_mut().connect(), &st, cx) - .map_err(map_to_ioerr) + .map_err(Into::>::into) }) .await?; diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 8cb36f2c..9c9fa3eb 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.4.14" +version = "0.5.0" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" diff --git a/ntex/src/connect/io.rs b/ntex/src/connect/io.rs deleted file mode 100644 index c949add3..00000000 --- a/ntex/src/connect/io.rs +++ /dev/null @@ -1,148 +0,0 @@ -use std::task::{Context, Poll}; -use std::{future::Future, pin::Pin}; - -use crate::io::Io; -use crate::service::{Service, ServiceFactory}; -use crate::util::{PoolId, PoolRef, Ready}; - -use super::service::ConnectServiceResponse; -use super::{Address, Connect, ConnectError, Connector}; - -pub struct IoConnector { - inner: Connector, - pool: PoolRef, -} - -impl IoConnector { - /// Construct new connect service with custom dns resolver - pub fn new() -> Self { - IoConnector { - inner: Connector::new(), - pool: PoolId::P0.pool_ref(), - } - } - - /// Set memory pool. - /// - /// Use specified memory pool for memory allocations. By default P0 - /// memory pool is used. - pub fn memory_pool(mut self, id: PoolId) -> Self { - self.pool = id.pool_ref(); - self - } -} - -impl IoConnector { - /// Resolve and connect to remote host - pub fn connect(&self, message: U) -> IoConnectServiceResponse - where - Connect: From, - { - IoConnectServiceResponse { - inner: self.inner.call(message.into()), - pool: self.pool, - } - } -} - -impl Default for IoConnector { - fn default() -> Self { - IoConnector::new() - } -} - -impl Clone for IoConnector { - fn clone(&self) -> Self { - IoConnector { - inner: self.inner.clone(), - pool: self.pool, - } - } -} - -impl ServiceFactory for IoConnector { - type Request = Connect; - type Response = Io; - type Error = ConnectError; - type Config = (); - type Service = IoConnector; - type InitError = (); - type Future = Ready; - - #[inline] - fn new_service(&self, _: ()) -> Self::Future { - Ready::Ok(self.clone()) - } -} - -impl Service for IoConnector { - type Request = Connect; - type Response = Io; - type Error = ConnectError; - type Future = IoConnectServiceResponse; - - #[inline] - fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - #[inline] - fn call(&self, req: Connect) -> Self::Future { - self.connect(req) - } -} - -#[doc(hidden)] -pub struct IoConnectServiceResponse { - inner: ConnectServiceResponse, - pool: PoolRef, -} - -impl Future for IoConnectServiceResponse { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match Pin::new(&mut self.inner).poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(stream)) => { - Poll::Ready(Ok(Io::with_memory_pool(stream, self.pool))) - } - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[crate::rt_test] - async fn test_connect() { - let server = crate::server::test_server(|| { - crate::service::fn_service(|_| async { Ok::<_, ()>(()) }) - }); - - let srv = IoConnector::default(); - let result = srv.connect("").await; - assert!(result.is_err()); - let result = srv.connect("localhost:99999").await; - assert!(result.is_err()); - - let srv = IoConnector::default(); - let result = srv.connect(format!("{}", server.addr())).await; - assert!(result.is_ok()); - - let msg = Connect::new(format!("{}", server.addr())).set_addrs(vec![ - format!("127.0.0.1:{}", server.addr().port() - 1) - .parse() - .unwrap(), - server.addr(), - ]); - let result = crate::connect::connect(msg).await; - assert!(result.is_ok()); - - let msg = Connect::new(server.addr()); - let result = crate::connect::connect(msg).await; - assert!(result.is_ok()); - } -} diff --git a/ntex/src/connect/mod.rs b/ntex/src/connect/mod.rs index 37344ca2..7e329614 100644 --- a/ntex/src/connect/mod.rs +++ b/ntex/src/connect/mod.rs @@ -2,7 +2,6 @@ use std::future::Future; mod error; -mod io; mod message; mod resolve; mod service; @@ -13,19 +12,18 @@ mod uri; #[cfg(feature = "openssl")] pub mod openssl; -#[cfg(feature = "rustls")] -pub mod rustls; - -use crate::rt::net::TcpStream; +//#[cfg(feature = "rustls")] +//pub mod rustls; pub use self::error::ConnectError; -pub use self::io::IoConnector; pub use self::message::{Address, Connect}; pub use self::resolve::Resolver; pub use self::service::Connector; +use crate::io::Io; + /// Resolve and connect to remote host -pub fn connect(message: U) -> impl Future> +pub fn connect(message: U) -> impl Future> where T: Address + 'static, Connect: From, diff --git a/ntex/src/connect/openssl.rs b/ntex/src/connect/openssl.rs index 7e7c1c39..15c7fd28 100644 --- a/ntex/src/connect/openssl.rs +++ b/ntex/src/connect/openssl.rs @@ -2,14 +2,12 @@ use std::{future::Future, io, pin::Pin, task::Context, task::Poll}; use ntex_openssl::{SslConnector as IoSslConnector, SslFilter}; pub use open_ssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; -pub use tokio_openssl::SslStream; use crate::io::{DefaultFilter, Io}; -use crate::rt::net::TcpStream; use crate::service::{Service, ServiceFactory}; use crate::util::Ready; -use super::{Address, Connect, ConnectError, Connector, IoConnector as BaseIoConnector}; +use super::{Address, Connect, ConnectError, Connector}; pub struct OpensslConnector { connector: Connector, @@ -27,103 +25,6 @@ impl OpensslConnector { } impl OpensslConnector { - /// Resolve and connect to remote host - pub fn connect( - &self, - message: U, - ) -> impl Future, ConnectError>> - where - Connect: From, - { - let message = Connect::from(message); - let host = message.host().to_string(); - let conn = self.connector.call(message); - let openssl = self.openssl.clone(); - - async move { - let io = conn.await?; - trace!("SSL Handshake start for: {:?}", host); - - match openssl.configure() { - Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()), - Ok(config) => { - let config = config - .into_ssl(&host) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - let mut io = SslStream::new(config, io) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - match Pin::new(&mut io).connect().await { - Ok(_) => { - trace!("SSL Handshake success: {:?}", host); - Ok(io) - } - Err(e) => { - trace!("SSL Handshake error: {:?}", e); - Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)) - .into()) - } - } - } - } - } - } -} - -impl Clone for OpensslConnector { - fn clone(&self) -> Self { - OpensslConnector { - connector: self.connector.clone(), - openssl: self.openssl.clone(), - } - } -} - -impl ServiceFactory for OpensslConnector { - type Request = Connect; - type Response = SslStream; - type Error = ConnectError; - type Config = (); - type Service = OpensslConnector; - type InitError = (); - type Future = Ready; - - fn new_service(&self, _: ()) -> Self::Future { - Ready::Ok(self.clone()) - } -} - -impl Service for OpensslConnector { - type Request = Connect; - type Response = SslStream; - type Error = ConnectError; - type Future = Pin>>>; - - #[inline] - fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&self, req: Connect) -> Self::Future { - Box::pin(self.connect(req)) - } -} - -pub struct IoConnector { - connector: BaseIoConnector, - openssl: SslConnector, -} - -impl IoConnector { - /// Construct new OpensslConnectService factory - pub fn new(connector: SslConnector) -> Self { - IoConnector { - connector: BaseIoConnector::default(), - openssl: connector, - } - } -} - -impl IoConnector { /// Resolve and connect to remote host pub fn connect( &self, @@ -164,21 +65,21 @@ impl IoConnector { } } -impl Clone for IoConnector { +impl Clone for OpensslConnector { fn clone(&self) -> Self { - IoConnector { + OpensslConnector { connector: self.connector.clone(), openssl: self.openssl.clone(), } } } -impl ServiceFactory for IoConnector { +impl ServiceFactory for OpensslConnector { type Request = Connect; type Response = Io>; type Error = ConnectError; type Config = (); - type Service = IoConnector; + type Service = OpensslConnector; type InitError = (); type Future = Ready; @@ -187,7 +88,7 @@ impl ServiceFactory for IoConnector { } } -impl Service for IoConnector { +impl Service for OpensslConnector { type Request = Connect; type Response = Io>; type Error = ConnectError; diff --git a/ntex/src/connect/service.rs b/ntex/src/connect/service.rs index c495f386..38060033 100644 --- a/ntex/src/connect/service.rs +++ b/ntex/src/connect/service.rs @@ -1,14 +1,16 @@ use std::task::{Context, Poll}; use std::{collections::VecDeque, future::Future, io, net::SocketAddr, pin::Pin}; +use crate::io::Io; use crate::rt::net::TcpStream; use crate::service::{Service, ServiceFactory}; -use crate::util::{Either, Ready}; +use crate::util::{Either, PoolId, PoolRef, Ready}; use super::{Address, Connect, ConnectError, Resolver}; pub struct Connector { resolver: Resolver, + pool: PoolRef, } impl Connector { @@ -16,8 +18,18 @@ impl Connector { pub fn new() -> Self { Connector { resolver: Resolver::new(), + pool: PoolId::P0.pool_ref(), } } + + /// Set memory pool. + /// + /// Use specified memory pool for memory allocations. By default P0 + /// memory pool is used. + pub fn memory_pool(mut self, id: PoolId) -> Self { + self.pool = id.pool_ref(); + self + } } impl Connector { @@ -25,11 +37,14 @@ impl Connector { pub fn connect( &self, message: U, - ) -> impl Future> + ) -> impl Future> where Connect: From, { - ConnectServiceResponse::new(self.resolver.call(message.into())) + ConnectServiceResponse { + state: ConnectState::Resolve(self.resolver.call(message.into())), + pool: self.pool, + } } } @@ -43,13 +58,14 @@ impl Clone for Connector { fn clone(&self) -> Self { Connector { resolver: self.resolver.clone(), + pool: self.pool, } } } impl ServiceFactory for Connector { type Request = Connect; - type Response = TcpStream; + type Response = Io; type Error = ConnectError; type Config = (); type Service = Connector; @@ -64,7 +80,7 @@ impl ServiceFactory for Connector { impl Service for Connector { type Request = Connect; - type Response = TcpStream; + type Response = Io; type Error = ConnectError; type Future = ConnectServiceResponse; @@ -87,18 +103,20 @@ enum ConnectState { #[doc(hidden)] pub struct ConnectServiceResponse { state: ConnectState, + pool: PoolRef, } impl ConnectServiceResponse { pub(super) fn new(fut: as Service>::Future) -> Self { - ConnectServiceResponse { + Self { state: ConnectState::Resolve(fut), + pool: PoolId::P0.pool_ref(), } } } impl Future for ConnectServiceResponse { - type Output = Result; + type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.state { @@ -126,7 +144,12 @@ impl Future for ConnectServiceResponse { } } }, - ConnectState::Connect(ref mut fut) => Pin::new(fut).poll(cx), + ConnectState::Connect(ref mut fut) => match Pin::new(fut).poll(cx)? { + Poll::Pending => Poll::Pending, + Poll::Ready(stream) => { + Poll::Ready(Ok(Io::with_memory_pool(stream, self.pool))) + } + }, } } } diff --git a/ntex/src/framed/dispatcher.rs b/ntex/src/framed/dispatcher.rs deleted file mode 100644 index 416b8ae5..00000000 --- a/ntex/src/framed/dispatcher.rs +++ /dev/null @@ -1,901 +0,0 @@ -//! Framed transport dispatcher -use std::{ - cell::Cell, cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context, - task::Poll, time, time::Instant, -}; - -use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; -use crate::framed::{DispatchItem, Read, ReadTask, State, Timer, Write, WriteTask}; -use crate::service::{IntoService, Service}; -use crate::time::Seconds; -use crate::util::{Either, Pool}; - -type Response = ::Item; - -pin_project_lite::pin_project! { - /// Framed dispatcher - is a future that reads frames from Framed object - /// and pass then to the service. - pub struct Dispatcher - where - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, - U: Encoder, - U: Decoder, - ::Item: 'static, - { - service: S, - inner: DispatcherInner, - #[pin] - fut: Option, - } -} - -struct DispatcherInner -where - S: Service, Response = Option>>, - U: Encoder + Decoder, -{ - st: Cell, - state: State, - timer: Timer, - ka_timeout: Seconds, - ka_updated: Cell, - error: Cell>, - shared: Rc>, - pool: Pool, -} - -struct DispatcherShared -where - S: Service, Response = Option>>, - U: Encoder + Decoder, -{ - codec: U, - error: Cell::Error>>>, - inflight: Cell, -} - -#[derive(Copy, Clone, Debug)] -enum DispatcherState { - Processing, - Backpressure, - Stop, - Shutdown, -} - -enum DispatcherError { - KeepAlive, - Encoder(U), - Service(S), -} - -enum PollService { - Item(DispatchItem), - ServiceError, - Ready, -} - -impl From> for DispatcherError { - fn from(err: Either) -> Self { - match err { - Either::Left(err) => DispatcherError::Service(err), - Either::Right(err) => DispatcherError::Encoder(err), - } - } -} - -impl Dispatcher -where - S: Service, Response = Option>> + 'static, - U: Decoder + Encoder + 'static, - ::Item: 'static, -{ - /// Construct new `Dispatcher` instance. - pub fn new>( - io: T, - codec: U, - state: State, - service: F, - timer: Timer, - ) -> Self - where - T: AsyncRead + AsyncWrite + Unpin + 'static, - { - let io = Rc::new(RefCell::new(io)); - - // start support tasks - crate::rt::spawn(ReadTask::new(io.clone(), state.clone())); - crate::rt::spawn(WriteTask::new(io, state.clone())); - - Self::from_state(codec, state, service, timer) - } - - /// Construct new `Dispatcher` instance. - pub fn from_state>( - codec: U, - state: State, - service: F, - timer: Timer, - ) -> Self { - let updated = timer.now(); - let ka_timeout = Seconds(30); - - // register keepalive timer - let expire = updated + time::Duration::from(ka_timeout); - timer.register(expire, expire, &state); - - Dispatcher { - service: service.into_service(), - fut: None, - inner: DispatcherInner { - pool: state.memory_pool().pool(), - ka_updated: Cell::new(updated), - error: Cell::new(None), - st: Cell::new(DispatcherState::Processing), - shared: Rc::new(DispatcherShared { - codec, - error: Cell::new(None), - inflight: Cell::new(0), - }), - state, - timer, - ka_timeout, - }, - } - } - - /// Set keep-alive timeout. - /// - /// To disable timeout set value to 0. - /// - /// By default keep-alive timeout is set to 30 seconds. - pub fn keepalive_timeout(mut self, timeout: Seconds) -> Self { - // register keepalive timer - let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka()); - if timeout.is_zero() { - self.inner.timer.unregister(prev, &self.inner.state); - } else { - let expire = self.inner.ka_updated.get() + time::Duration::from(timeout); - self.inner.timer.register(expire, prev, &self.inner.state); - } - self.inner.ka_timeout = timeout; - - self - } - - /// Set connection disconnect timeout in seconds. - /// - /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete - /// within this time, the connection get dropped. - /// - /// To disable timeout set value to 0. - /// - /// By default disconnect timeout is set to 1 seconds. - pub fn disconnect_timeout(self, val: Seconds) -> Self { - self.inner.state.set_disconnect_timeout(val); - self - } -} - -impl DispatcherShared -where - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, - U: Encoder + Decoder, - ::Item: 'static, -{ - fn handle_result(&self, item: Result, write: Write<'_>) { - self.inflight.set(self.inflight.get() - 1); - match write.encode_result(item, &self.codec) { - Ok(true) => (), - Ok(false) => write.enable_backpressure(None), - Err(err) => self.error.set(Some(err.into())), - } - write.wake_dispatcher(); - } -} - -impl Future for Dispatcher -where - S: Service, Response = Option>> + 'static, - U: Decoder + Encoder + 'static, - ::Item: 'static, -{ - type Output = Result<(), S::Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().project(); - let slf = &this.inner; - let state = &slf.state; - let read = state.read(); - let write = state.write(); - - // handle service response future - if let Some(fut) = this.fut.as_mut().as_pin_mut() { - match fut.poll(cx) { - Poll::Pending => (), - Poll::Ready(item) => { - this.fut.set(None); - slf.shared.inflight.set(slf.shared.inflight.get() - 1); - slf.handle_result(item, write); - } - } - } - - // handle memory pool pressure - if slf.pool.poll_ready(cx).is_pending() { - read.pause(cx.waker()); - return Poll::Pending; - } - - loop { - match slf.st.get() { - DispatcherState::Processing => { - let result = match slf.poll_service(this.service, cx, read) { - Poll::Pending => return Poll::Pending, - Poll::Ready(result) => result, - }; - - let item = match result { - PollService::Ready => { - if !write.is_ready() { - // instruct write task to notify dispatcher when data is flushed - write.enable_backpressure(Some(cx.waker())); - slf.st.set(DispatcherState::Backpressure); - DispatchItem::WBackPressureEnabled - } else if read.is_ready() { - // decode incoming bytes if buffer is ready - match read.decode(&slf.shared.codec) { - Ok(Some(el)) => { - slf.update_keepalive(); - DispatchItem::Item(el) - } - Ok(None) => { - log::trace!("not enough data to decode next frame, register dispatch task"); - read.wake(cx.waker()); - return Poll::Pending; - } - Err(err) => { - slf.st.set(DispatcherState::Stop); - slf.unregister_keepalive(); - DispatchItem::DecoderError(err) - } - } - } else { - // no new events - state.register_dispatcher(cx.waker()); - return Poll::Pending; - } - } - PollService::Item(item) => item, - PollService::ServiceError => continue, - }; - - // call service - if this.fut.is_none() { - // optimize first service call - this.fut.set(Some(this.service.call(item))); - match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) { - Poll::Ready(res) => { - this.fut.set(None); - slf.handle_result(res, write); - } - Poll::Pending => { - slf.shared.inflight.set(slf.shared.inflight.get() + 1) - } - } - } else { - slf.spawn_service_call(this.service.call(item)); - } - } - // handle write back-pressure - DispatcherState::Backpressure => { - let result = match slf.poll_service(this.service, cx, read) { - Poll::Ready(result) => result, - Poll::Pending => return Poll::Pending, - }; - let item = match result { - PollService::Ready => { - if write.is_ready() { - slf.st.set(DispatcherState::Processing); - DispatchItem::WBackPressureDisabled - } else { - return Poll::Pending; - } - } - PollService::Item(item) => item, - PollService::ServiceError => continue, - }; - - // call service - if this.fut.is_none() { - // optimize first service call - this.fut.set(Some(this.service.call(item))); - match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) { - Poll::Ready(res) => { - this.fut.set(None); - slf.handle_result(res, write); - } - Poll::Pending => { - slf.shared.inflight.set(slf.shared.inflight.get() + 1) - } - } - } else { - slf.spawn_service_call(this.service.call(item)); - } - } - // drain service responses - DispatcherState::Stop => { - // service may relay on poll_ready for response results - if !this.inner.state.is_dispatcher_ready_err() { - let _ = this.service.poll_ready(cx); - } - - if slf.shared.inflight.get() == 0 { - slf.st.set(DispatcherState::Shutdown); - state.shutdown_io(); - } else { - state.register_dispatcher(cx.waker()); - return Poll::Pending; - } - } - // shutdown service - DispatcherState::Shutdown => { - let err = slf.error.take(); - - return if this.service.poll_shutdown(cx, err.is_some()).is_ready() { - log::trace!("service shutdown is completed, stop"); - - Poll::Ready(if let Some(err) = err { - Err(err) - } else { - Ok(()) - }) - } else { - slf.error.set(err); - Poll::Pending - }; - } - } - } - } -} - -impl DispatcherInner -where - S: Service, Response = Option>> + 'static, - U: Decoder + Encoder + 'static, -{ - /// spawn service call - fn spawn_service_call(&self, fut: S::Future) { - self.shared.inflight.set(self.shared.inflight.get() + 1); - - let st = self.state.clone(); - let shared = self.shared.clone(); - crate::rt::spawn(async move { - let item = fut.await; - shared.handle_result(item, st.write()); - }); - } - - fn handle_result( - &self, - item: Result::Item>, S::Error>, - write: Write<'_>, - ) { - match write.encode_result(item, &self.shared.codec) { - Ok(true) => (), - Ok(false) => write.enable_backpressure(None), - Err(Either::Left(err)) => { - self.error.set(Some(err)); - } - Err(Either::Right(err)) => { - self.shared.error.set(Some(DispatcherError::Encoder(err))) - } - } - } - - fn poll_service( - &self, - srv: &S, - cx: &mut Context<'_>, - read: Read<'_>, - ) -> Poll> { - match srv.poll_ready(cx) { - Poll::Ready(Ok(_)) => { - // service is ready, wake io read task - read.resume(); - - // check keepalive timeout - self.check_keepalive(); - - // check for errors - Poll::Ready(if let Some(err) = self.shared.error.take() { - log::trace!("error occured, stopping dispatcher"); - self.unregister_keepalive(); - self.st.set(DispatcherState::Stop); - - match err { - DispatcherError::KeepAlive => { - PollService::Item(DispatchItem::KeepAliveTimeout) - } - DispatcherError::Encoder(err) => { - PollService::Item(DispatchItem::EncoderError(err)) - } - DispatcherError::Service(err) => { - self.error.set(Some(err)); - PollService::ServiceError - } - } - } else if self.state.is_dispatcher_stopped() { - log::trace!("dispatcher is instructed to stop"); - - self.unregister_keepalive(); - - // process unhandled data - if let Ok(Some(el)) = read.decode(&self.shared.codec) { - PollService::Item(DispatchItem::Item(el)) - } else { - self.st.set(DispatcherState::Stop); - - // get io error - if let Some(err) = self.state.take_io_error() { - PollService::Item(DispatchItem::IoError(err)) - } else { - PollService::ServiceError - } - } - } else { - PollService::Ready - }) - } - // pause io read task - Poll::Pending => { - log::trace!("service is not ready, register dispatch task"); - read.pause(cx.waker()); - Poll::Pending - } - // handle service readiness error - Poll::Ready(Err(err)) => { - log::trace!("service readiness check failed, stopping"); - self.st.set(DispatcherState::Stop); - self.error.set(Some(err)); - self.unregister_keepalive(); - self.state.dispatcher_ready_err(); - Poll::Ready(PollService::ServiceError) - } - } - } - - fn ka(&self) -> Seconds { - self.ka_timeout - } - - fn ka_enabled(&self) -> bool { - self.ka_timeout.non_zero() - } - - /// check keepalive timeout - fn check_keepalive(&self) { - if self.state.is_keepalive() { - log::trace!("keepalive timeout"); - if let Some(err) = self.shared.error.take() { - self.shared.error.set(Some(err)); - } else { - self.shared.error.set(Some(DispatcherError::KeepAlive)); - } - } - } - - /// update keep-alive timer - fn update_keepalive(&self) { - if self.ka_enabled() { - let updated = self.timer.now(); - if updated != self.ka_updated.get() { - let ka = time::Duration::from(self.ka()); - self.timer.register( - updated + ka, - self.ka_updated.get() + ka, - &self.state, - ); - self.ka_updated.set(updated); - } - } - } - - /// unregister keep-alive timer - fn unregister_keepalive(&self) { - if self.ka_enabled() { - self.timer.unregister( - self.ka_updated.get() + time::Duration::from(self.ka()), - &self.state, - ); - } - } -} - -#[cfg(test)] -mod tests { - use rand::Rng; - use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex}; - use std::time::Duration; - - use crate::codec::BytesCodec; - use crate::testing::Io; - use crate::time::{sleep, Millis}; - use crate::util::{Bytes, PoolRef, Ready}; - - use super::*; - - impl Dispatcher - where - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, - U: Decoder + Encoder + 'static, - ::Item: 'static, - { - /// Construct new `Dispatcher` instance - pub(crate) fn debug>( - io: T, - codec: U, - service: F, - ) -> (Self, State) - where - T: AsyncRead + AsyncWrite + Unpin + 'static, - { - let timer = Timer::default(); - let ka_timeout = Seconds(1); - let ka_updated = timer.now(); - let state = State::new(); - let io = Rc::new(RefCell::new(io)); - let shared = Rc::new(DispatcherShared { - codec: codec, - error: Cell::new(None), - inflight: Cell::new(0), - }); - - let expire = ka_updated + Duration::from_millis(500); - timer.register(expire, expire, &state); - - crate::rt::spawn(ReadTask::new(io.clone(), state.clone())); - crate::rt::spawn(WriteTask::new(io.clone(), state.clone())); - - ( - Dispatcher { - service: service.into_service(), - fut: None, - inner: DispatcherInner { - shared, - timer, - ka_timeout, - ka_updated: Cell::new(ka_updated), - state: state.clone(), - error: Cell::new(None), - st: Cell::new(DispatcherState::Processing), - pool: state.memory_pool().pool(), - }, - }, - state, - ) - } - } - - #[crate::rt_test] - async fn test_basic() { - let (client, server) = Io::create(); - client.remote_buffer_cap(1024); - client.write("GET /test HTTP/1\r\n\r\n"); - - let (disp, _) = Dispatcher::debug( - server, - BytesCodec, - crate::service::fn_service(|msg: DispatchItem| async move { - sleep(Millis(50)).await; - if let DispatchItem::Item(msg) = msg { - Ok::<_, ()>(Some(msg.freeze())) - } else { - panic!() - } - }), - ); - crate::rt::spawn(async move { - let _ = disp.await; - }); - - sleep(Millis(25)).await; - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - - client.write("GET /test HTTP/1\r\n\r\n"); - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - - client.close().await; - assert!(client.is_server_dropped()); - } - - #[crate::rt_test] - async fn test_sink() { - let (client, server) = Io::create(); - client.remote_buffer_cap(1024); - client.write("GET /test HTTP/1\r\n\r\n"); - - let (disp, st) = Dispatcher::debug( - server, - BytesCodec, - crate::service::fn_service(|msg: DispatchItem| async move { - if let DispatchItem::Item(msg) = msg { - Ok::<_, ()>(Some(msg.freeze())) - } else { - panic!() - } - }), - ); - crate::rt::spawn(async move { - let _ = disp.disconnect_timeout(Seconds(1)).await; - }); - - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - - assert!(st - .write() - .encode(Bytes::from_static(b"test"), &mut BytesCodec) - .is_ok()); - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"test")); - - st.close(); - sleep(Millis(1100)).await; - assert!(client.is_server_dropped()); - } - - #[crate::rt_test] - async fn test_err_in_service() { - let (client, server) = Io::create(); - client.remote_buffer_cap(0); - client.write("GET /test HTTP/1\r\n\r\n"); - - let (disp, state) = Dispatcher::debug( - server, - BytesCodec, - crate::service::fn_service(|_: DispatchItem| async move { - Err::, _>(()) - }), - ); - state - .write() - .encode( - Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), - &mut BytesCodec, - ) - .unwrap(); - crate::rt::spawn(async move { - let _ = disp.await; - }); - - // buffer should be flushed - client.remote_buffer_cap(1024); - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - - // write side must be closed, dispatcher waiting for read side to close - assert!(client.is_closed()); - - // close read side - client.close().await; - assert!(client.is_server_dropped()); - } - - #[crate::rt_test] - async fn test_err_in_service_ready() { - let (client, server) = Io::create(); - client.remote_buffer_cap(0); - client.write("GET /test HTTP/1\r\n\r\n"); - - let counter = Rc::new(Cell::new(0)); - - struct Srv(Rc>); - - impl Service for Srv { - type Request = DispatchItem; - type Response = Option>; - type Error = (); - type Future = Ready>, ()>; - - fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { - self.0.set(self.0.get() + 1); - Poll::Ready(Err(())) - } - - fn call(&self, _: DispatchItem) -> Self::Future { - Ready::Ok(None) - } - } - - let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone())); - state - .write() - .encode( - Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), - &mut BytesCodec, - ) - .unwrap(); - crate::rt::spawn(async move { - let _ = disp.await; - }); - - // buffer should be flushed - client.remote_buffer_cap(1024); - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - - // write side must be closed, dispatcher waiting for read side to close - assert!(client.is_closed()); - - // close read side - client.close().await; - assert!(client.is_server_dropped()); - - // service must be checked for readiness only once - assert_eq!(counter.get(), 1); - } - - #[crate::rt_test] - async fn test_write_backpressure() { - let (client, server) = Io::create(); - // do not allow to write to socket - client.remote_buffer_cap(0); - client.write("GET /test HTTP/1\r\n\r\n"); - - let data = Arc::new(Mutex::new(RefCell::new(Vec::new()))); - let data2 = data.clone(); - - let (disp, state) = Dispatcher::debug( - server, - BytesCodec, - crate::service::fn_service(move |msg: DispatchItem| { - let data = data2.clone(); - async move { - match msg { - DispatchItem::Item(_) => { - data.lock().unwrap().borrow_mut().push(0); - let bytes = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(65_536) - .map(char::from) - .collect::(); - return Ok::<_, ()>(Some(Bytes::from(bytes))); - } - DispatchItem::WBackPressureEnabled => { - data.lock().unwrap().borrow_mut().push(1); - } - DispatchItem::WBackPressureDisabled => { - data.lock().unwrap().borrow_mut().push(2); - } - _ => (), - } - Ok(None) - } - }), - ); - let pool = PoolRef::default(); - pool.set_read_params(8 * 1024, 1024); - pool.set_write_params(16 * 1024, 1024); - crate::rt::spawn(async move { - let _ = disp.await; - }); - - let buf = client.read_any(); - assert_eq!(buf, Bytes::from_static(b"")); - client.write("GET /test HTTP/1\r\n\r\n"); - sleep(Millis(25)).await; - - // buf must be consumed - assert_eq!(client.remote_buffer(|buf| buf.len()), 0); - - // response message - assert!(!state.write().is_ready()); - assert_eq!(state.write().with_buf(|buf| buf.len()), 65536); - - client.remote_buffer_cap(10240); - sleep(Millis(50)).await; - assert_eq!(state.write().with_buf(|buf| buf.len()), 55296); - - client.remote_buffer_cap(45056); - sleep(Millis(50)).await; - assert_eq!(state.write().with_buf(|buf| buf.len()), 10240); - - // backpressure disabled - assert!(state.write().is_ready()); - assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]); - } - - #[crate::rt_test] - async fn test_keepalive() { - let (client, server) = Io::create(); - // do not allow to write to socket - client.remote_buffer_cap(1024); - client.write("GET /test HTTP/1\r\n\r\n"); - - let data = Arc::new(Mutex::new(RefCell::new(Vec::new()))); - let data2 = data.clone(); - - let (disp, state) = Dispatcher::debug( - server, - BytesCodec, - crate::service::fn_service(move |msg: DispatchItem| { - let data = data2.clone(); - async move { - match msg { - DispatchItem::Item(bytes) => { - data.lock().unwrap().borrow_mut().push(0); - return Ok::<_, ()>(Some(bytes.freeze())); - } - DispatchItem::KeepAliveTimeout => { - data.lock().unwrap().borrow_mut().push(1); - } - _ => (), - } - Ok(None) - } - }), - ); - crate::rt::spawn(async move { - let _ = disp - .keepalive_timeout(Seconds::ZERO) - .keepalive_timeout(Seconds(1)) - .await; - }); - - state.set_disconnect_timeout(Seconds(1)); - - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - sleep(Millis(3500)).await; - - // write side must be closed, dispatcher should fail with keep-alive - let flags = state.flags(); - assert!(state.is_io_err()); - assert!(state.is_io_shutdown()); - assert!(flags.contains(crate::framed::state::Flags::IO_SHUTDOWN)); - assert!(client.is_closed()); - assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]); - } - - #[crate::rt_test] - async fn test_unhandled_data() { - let handled = Arc::new(AtomicBool::new(false)); - let handled2 = handled.clone(); - - let (client, server) = Io::create(); - client.remote_buffer_cap(1024); - client.write("GET /test HTTP/1\r\n\r\n"); - - let (disp, _) = Dispatcher::debug( - server, - BytesCodec, - crate::service::fn_service(move |msg: DispatchItem| { - handled2.store(true, Relaxed); - async move { - sleep(Millis(50)).await; - if let DispatchItem::Item(msg) = msg { - Ok::<_, ()>(Some(msg.freeze())) - } else { - panic!() - } - } - }), - ); - client.close().await; - crate::rt::spawn(async move { - let _ = disp.await; - }); - sleep(Millis(50)).await; - - assert!(handled.load(Relaxed)); - } -} diff --git a/ntex/src/framed/mod.rs b/ntex/src/framed/mod.rs deleted file mode 100644 index 3054c448..00000000 --- a/ntex/src/framed/mod.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::{fmt, io}; - -mod dispatcher; -mod read; -mod state; -mod time; -mod write; - -pub use self::dispatcher::Dispatcher; -pub use self::read::ReadTask; -pub use self::state::{OnDisconnect, Read, State, Write}; -pub use self::time::Timer; -pub use self::write::WriteTask; - -use crate::codec::{Decoder, Encoder}; - -/// Framed transport item -pub enum DispatchItem { - Item(::Item), - /// Write back-pressure enabled - WBackPressureEnabled, - /// Write back-pressure disabled - WBackPressureDisabled, - /// Keep alive timeout - KeepAliveTimeout, - /// Decoder parse error - DecoderError(::Error), - /// Encoder parse error - EncoderError(::Error), - /// Unexpected io error - IoError(io::Error), -} - -impl fmt::Debug for DispatchItem -where - U: Encoder + Decoder, - ::Item: fmt::Debug, -{ - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - DispatchItem::Item(ref item) => { - write!(fmt, "DispatchItem::Item({:?})", item) - } - DispatchItem::WBackPressureEnabled => { - write!(fmt, "DispatchItem::WBackPressureEnabled") - } - DispatchItem::WBackPressureDisabled => { - write!(fmt, "DispatchItem::WBackPressureDisabled") - } - DispatchItem::KeepAliveTimeout => { - write!(fmt, "DispatchItem::KeepAliveTimeout") - } - DispatchItem::EncoderError(ref e) => { - write!(fmt, "DispatchItem::EncoderError({:?})", e) - } - DispatchItem::DecoderError(ref e) => { - write!(fmt, "DispatchItem::DecoderError({:?})", e) - } - DispatchItem::IoError(ref e) => { - write!(fmt, "DispatchItem::IoError({:?})", e) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::codec::BytesCodec; - - #[test] - fn test_fmt() { - type T = DispatchItem; - - let err = T::EncoderError(io::Error::new(io::ErrorKind::Other, "err")); - assert!(format!("{:?}", err).contains("DispatchItem::Encoder")); - let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err")); - assert!(format!("{:?}", err).contains("DispatchItem::Decoder")); - let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err")); - assert!(format!("{:?}", err).contains("DispatchItem::IoError")); - - assert!(format!("{:?}", T::WBackPressureEnabled) - .contains("DispatchItem::WBackPressureEnabled")); - assert!(format!("{:?}", T::WBackPressureDisabled) - .contains("DispatchItem::WBackPressureDisabled")); - assert!(format!("{:?}", T::KeepAliveTimeout) - .contains("DispatchItem::KeepAliveTimeout")); - } -} diff --git a/ntex/src/framed/read.rs b/ntex/src/framed/read.rs deleted file mode 100644 index 9b908f43..00000000 --- a/ntex/src/framed/read.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll}; - -use crate::codec::{AsyncRead, AsyncWrite}; -use crate::framed::State; - -/// Read io task -pub struct ReadTask -where - T: AsyncRead + AsyncWrite + Unpin, -{ - io: Rc>, - state: State, -} - -impl ReadTask -where - T: AsyncRead + AsyncWrite + Unpin, -{ - /// Create new read io task - pub fn new(io: Rc>, state: State) -> Self { - Self { io, state } - } -} - -impl Future for ReadTask -where - T: AsyncRead + AsyncWrite + Unpin, -{ - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.state.is_io_shutdown() { - log::trace!("read task is instructed to shutdown"); - Poll::Ready(()) - } else if self.state.is_io_stop() { - self.state.wake_dispatcher(); - Poll::Ready(()) - } else if self.state.is_read_paused() { - self.state.register_read_task(cx.waker()); - Poll::Pending - } else { - let mut io = self.io.borrow_mut(); - if self.state.read_io(&mut *io, cx) { - Poll::Pending - } else { - Poll::Ready(()) - } - } - } -} diff --git a/ntex/src/framed/state.rs b/ntex/src/framed/state.rs deleted file mode 100644 index 7a023937..00000000 --- a/ntex/src/framed/state.rs +++ /dev/null @@ -1,1141 +0,0 @@ -//! Framed transport dispatcher -use std::task::{Context, Poll, Waker}; -use std::{cell::Cell, cell::RefCell, future::Future, hash, io, pin::Pin, rc::Rc}; - -use slab::Slab; - -use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; -use crate::task::LocalWaker; -use crate::time::Seconds; -use crate::util::{poll_fn, Buf, BytesMut, Either, PoolId, PoolRef}; - -bitflags::bitflags! { - pub struct Flags: u16 { - /// io error occured - const IO_ERR = 0b0000_0001; - /// stop io tasks - const IO_STOP = 0b0000_0010; - /// shutdown io tasks - const IO_SHUTDOWN = 0b0000_0100; - - /// pause io read - const RD_PAUSED = 0b0000_1000; - /// new data is available - const RD_READY = 0b0001_0000; - /// read buffer is full - const RD_BUF_FULL = 0b0010_0000; - - /// write buffer is full - const WR_BACKPRESSURE = 0b0001_0000_0000; - - /// dispatcher is marked stopped - const DSP_STOP = 0b0001_0000_0000_0000; - /// keep-alive timeout occured - const DSP_KEEPALIVE = 0b0010_0000_0000_0000; - /// dispatcher returned error - const DSP_ERR = 0b0100_0000_0000_0000; - /// dispatcher rediness error - const DSP_READY_ERR = 0b1000_0000_0000_0000; - } -} - -pub struct State(Rc); - -pub(crate) struct IoStateInner { - flags: Cell, - pool: Cell, - disconnect_timeout: Cell, - error: Cell>, - read_task: LocalWaker, - write_task: LocalWaker, - dispatch_task: LocalWaker, - read_buf: Cell>, - write_buf: Cell>, - on_disconnect: RefCell>>, -} - -impl IoStateInner { - fn insert_flags(&self, f: Flags) { - let mut flags = self.flags.get(); - flags.insert(f); - self.flags.set(flags); - } - - fn remove_flags(&self, f: Flags) { - let mut flags = self.flags.get(); - flags.remove(f); - self.flags.set(flags); - } - - fn get_read_buf(&self) -> BytesMut { - if let Some(buf) = self.read_buf.take() { - buf - } else { - self.pool.get().get_read_buf() - } - } - - fn get_write_buf(&self) -> BytesMut { - if let Some(buf) = self.write_buf.take() { - buf - } else { - self.pool.get().get_write_buf() - } - } - - fn release_read_buf(&self, buf: BytesMut) { - if buf.is_empty() { - self.pool.get().release_read_buf(buf); - } else { - self.read_buf.set(Some(buf)); - } - } - - fn release_write_buf(&self, buf: BytesMut) { - if buf.is_empty() { - self.pool.get().release_write_buf(buf); - } else { - self.write_buf.set(Some(buf)); - } - } -} - -impl Drop for IoStateInner { - fn drop(&mut self) { - if let Some(buf) = self.read_buf.take() { - self.pool.get().release_read_buf(buf); - } - if let Some(buf) = self.write_buf.take() { - self.pool.get().release_write_buf(buf); - } - } -} - -impl Clone for State { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl Eq for State {} - -impl PartialEq for State { - fn eq(&self, other: &Self) -> bool { - Rc::as_ptr(&self.0) == Rc::as_ptr(&other.0) - } -} - -impl hash::Hash for State { - fn hash(&self, state: &mut H) { - Rc::as_ptr(&self.0).hash(state); - } -} - -impl State { - #[inline] - /// Create `State` instance - pub fn new() -> Self { - Self::with_memory_pool(PoolId::DEFAULT.pool_ref()) - } - - #[inline] - /// Create `State` instance with specific memory pool. - pub fn with_memory_pool(pool: PoolRef) -> Self { - State(Rc::new(IoStateInner { - pool: Cell::new(pool), - flags: Cell::new(Flags::empty()), - error: Cell::new(None), - disconnect_timeout: Cell::new(Seconds(1)), - dispatch_task: LocalWaker::new(), - read_task: LocalWaker::new(), - write_task: LocalWaker::new(), - read_buf: Cell::new(None), - write_buf: Cell::new(None), - on_disconnect: RefCell::new(Slab::new()), - })) - } - - #[inline] - /// Create `State` from Framed - pub fn from_framed(framed: Framed) -> (Io, U, Self) { - let pool = PoolId::DEFAULT.pool_ref(); - let mut parts = framed.into_parts(); - let read_buf = if !parts.read_buf.is_empty() { - pool.move_in(&mut parts.read_buf); - Cell::new(Some(parts.read_buf)) - } else { - Cell::new(None) - }; - let write_buf = if !parts.write_buf.is_empty() { - pool.move_in(&mut parts.write_buf); - Cell::new(Some(parts.write_buf)) - } else { - Cell::new(None) - }; - - let state = State(Rc::new(IoStateInner { - read_buf, - write_buf, - pool: Cell::new(pool), - flags: Cell::new(Flags::empty()), - error: Cell::new(None), - disconnect_timeout: Cell::new(Seconds(1)), - dispatch_task: LocalWaker::new(), - read_task: LocalWaker::new(), - write_task: LocalWaker::new(), - on_disconnect: RefCell::new(Slab::new()), - })); - (parts.io, parts.codec, state) - } - - #[doc(hidden)] - #[inline] - /// Create `State` instance with custom params - pub fn with_params( - _max_read_buf_size: u16, - _max_write_buf_size: u16, - _min_buf_size: u16, - disconnect_timeout: Seconds, - ) -> Self { - State(Rc::new(IoStateInner { - pool: Cell::new(PoolId::DEFAULT.pool_ref()), - flags: Cell::new(Flags::empty()), - error: Cell::new(None), - disconnect_timeout: Cell::new(disconnect_timeout), - dispatch_task: LocalWaker::new(), - read_buf: Cell::new(None), - read_task: LocalWaker::new(), - write_buf: Cell::new(None), - write_task: LocalWaker::new(), - on_disconnect: RefCell::new(Slab::new()), - })) - } - - #[inline] - /// Convert State to a Framed instance - pub fn into_framed(self, io: Io, codec: U) -> Framed { - let mut parts = FramedParts::new(io, codec); - - parts.read_buf = if let Some(buf) = self.0.read_buf.take() { - buf - } else { - BytesMut::new() - }; - parts.write_buf = if let Some(buf) = self.0.write_buf.take() { - buf - } else { - BytesMut::new() - }; - Framed::from_parts(parts) - } - - pub(crate) fn keepalive_timeout(&self) { - let state = self.0.as_ref(); - state.dispatch_task.wake(); - state.insert_flags(Flags::DSP_KEEPALIVE); - } - - pub(super) fn get_disconnect_timeout(&self) -> Seconds { - self.0.disconnect_timeout.get() - } - - fn insert_flags(&self, f: Flags) { - let mut flags = self.0.flags.get(); - flags.insert(f); - self.0.flags.set(flags); - } - - fn remove_flags(&self, f: Flags) { - let mut flags = self.0.flags.get(); - flags.remove(f); - self.0.flags.set(flags); - } - - #[inline] - #[doc(hidden)] - /// Get current state flags - pub fn flags(&self) -> Flags { - self.0.flags.get() - } - - #[inline] - /// Get memory pool - pub fn memory_pool(&self) -> PoolRef { - self.0.pool.get() - } - - #[inline] - /// Set memory pool - pub fn set_memory_pool(&self, pool: PoolRef) { - if let Some(mut buf) = self.0.read_buf.take() { - pool.move_in(&mut buf); - self.0.read_buf.set(Some(buf)); - } - if let Some(mut buf) = self.0.write_buf.take() { - pool.move_in(&mut buf); - self.0.write_buf.set(Some(buf)); - } - self.0.pool.set(pool) - } - - #[doc(hidden)] - #[deprecated(since = "0.4.11", note = "Use memory pool config")] - #[inline] - /// Set read/write buffer sizes - /// - /// By default read max buf size is 8kb, write max buf size is 8kb - pub fn set_buffer_params( - &self, - _max_read_buf_size: u16, - _max_write_buf_size: u16, - _min_buf_size: u16, - ) { - } - - #[inline] - /// Set io disconnect timeout in secs - pub fn set_disconnect_timeout(&self, timeout: Seconds) { - self.0.disconnect_timeout.set(timeout) - } - - #[inline] - /// Notify when socket get disconnected - pub fn on_disconnect(&self) -> OnDisconnect { - OnDisconnect::new(self.0.clone(), self.0.flags.get().contains(Flags::IO_ERR)) - } - - fn notify_disconnect(&self) { - let mut slab = self.0.on_disconnect.borrow_mut(); - for item in slab.iter_mut() { - if let Some(waker) = item.1 { - waker.wake(); - } else { - *item.1 = Some(LocalWaker::default()) - } - } - } - - #[inline] - /// Check if io error occured in read or write task - pub fn is_io_err(&self) -> bool { - self.0.flags.get().contains(Flags::IO_ERR) - } - - pub(super) fn is_io_shutdown(&self) -> bool { - self.0 - .flags - .get() - .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) - } - - pub(super) fn is_io_stop(&self) -> bool { - self.0.flags.get().contains(Flags::IO_STOP) - } - - pub(super) fn is_read_paused(&self) -> bool { - self.0.flags.get().contains(Flags::RD_PAUSED) - } - - #[inline] - /// Check if keep-alive timeout occured - pub fn is_keepalive(&self) -> bool { - self.0.flags.get().contains(Flags::DSP_KEEPALIVE) - } - - #[inline] - /// Check if dispatcher marked stopped - pub fn is_dispatcher_stopped(&self) -> bool { - self.0.flags.get().contains(Flags::DSP_STOP) - } - - #[inline] - /// Check if dispatcher failed readiness check - pub fn is_dispatcher_ready_err(&self) -> bool { - self.0.flags.get().contains(Flags::DSP_READY_ERR) - } - - #[inline] - pub fn is_open(&self) -> bool { - !self - .0 - .flags - .get() - .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP) - } - - pub(crate) fn set_io_error(&self, err: Option) { - self.0.error.set(err); - self.0.read_task.wake(); - self.0.write_task.wake(); - self.0.dispatch_task.wake(); - self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP); - self.notify_disconnect(); - } - - pub(super) fn set_wr_shutdown_complete(&self) { - if !self.0.flags.get().contains(Flags::IO_ERR) { - self.notify_disconnect(); - self.insert_flags(Flags::IO_ERR); - self.0.read_task.wake(); - } - } - - pub(super) fn register_read_task(&self, waker: &Waker) { - self.0.read_task.register(waker); - } - - #[inline] - /// Stop io tasks - /// - /// Wake dispatcher when read or write task is stopped. - pub fn stop_io(&self, waker: &Waker) { - self.insert_flags(Flags::IO_STOP); - self.0.read_task.wake(); - self.0.write_task.wake(); - self.0.dispatch_task.register(waker); - } - - #[inline] - /// Gracefully shutdown read and write io tasks - pub fn shutdown_io(&self) { - let flags = self.0.flags.get(); - - if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { - log::trace!("initiate io shutdown {:?}", flags); - self.insert_flags(Flags::IO_SHUTDOWN); - self.0.read_task.wake(); - self.0.write_task.wake(); - } - } - - #[inline] - /// Take io error if any occured - pub fn take_io_error(&self) -> Option { - self.0.error.take() - } - - #[inline] - /// Reset io stop flags - pub fn reset_io_stop(&self) { - self.remove_flags(Flags::IO_STOP); - } - - #[inline] - /// Reset keep-alive error - pub fn reset_keepalive(&self) { - self.remove_flags(Flags::DSP_KEEPALIVE) - } - - #[inline] - /// Wake dispatcher task - pub fn wake_dispatcher(&self) { - self.0.dispatch_task.wake(); - } - - #[inline] - /// Register dispatcher task - pub fn register_dispatcher(&self, waker: &Waker) { - self.0.dispatch_task.register(waker); - } - - #[inline] - /// Mark dispatcher as stopped - pub fn dispatcher_stopped(&self) { - self.insert_flags(Flags::DSP_STOP); - } - - #[inline] - /// Mark dispatcher as failed readiness check - pub fn dispatcher_ready_err(&self) { - self.insert_flags(Flags::DSP_READY_ERR); - } - - #[inline] - /// Get api for read task - pub fn read(&'_ self) -> Read<'_> { - Read(self.0.as_ref()) - } - - #[inline] - /// Get api for write task - pub fn write(&'_ self) -> Write<'_> { - Write(self.0.as_ref()) - } - - #[inline] - /// Gracefully close connection - /// - /// First stop dispatcher, then dispatcher stops io tasks - pub fn close(&self) { - self.insert_flags(Flags::DSP_STOP); - self.0.dispatch_task.wake(); - } - - #[inline] - /// Force close connection - /// - /// Dispatcher does not wait for uncompleted responses, but flushes io buffers. - pub fn force_close(&self) { - log::trace!("force close framed object"); - self.insert_flags(Flags::DSP_STOP | Flags::IO_SHUTDOWN); - self.0.read_task.wake(); - self.0.write_task.wake(); - self.0.dispatch_task.wake(); - } -} - -impl State { - #[inline] - /// Read incoming io stream and decode codec item. - pub async fn next( - &self, - io: &mut T, - codec: &U, - ) -> Result, Either> - where - T: AsyncRead + AsyncWrite + Unpin, - U: Decoder, - { - let mut buf = self.0.get_read_buf(); - - loop { - let item = codec.decode(&mut buf); - let result = match item { - Ok(Some(el)) => Ok(Some(el)), - Ok(None) => { - let n = poll_fn(|cx| { - crate::codec::poll_read_buf(Pin::new(&mut *io), cx, &mut buf) - }) - .await - .map_err(Either::Right)?; - if n == 0 { - Ok(None) - } else { - continue; - } - } - Err(err) => { - self.set_io_error(None); - Err(Either::Left(err)) - } - }; - self.0.release_read_buf(buf); - return result; - } - } - - #[inline] - /// Encode item, send to a peer and then flush - pub async fn send( - &self, - io: &mut T, - codec: &U, - item: U::Item, - ) -> Result<(), Either> - where - T: AsyncRead + AsyncWrite + Unpin, - U: Encoder, - { - let mut buf = self.0.get_write_buf(); - codec.encode(item, &mut buf).map_err(Either::Left)?; - - self.0.write_buf.set(Some(buf)); - if !poll_fn(|cx| self.flush_io(io, cx)).await { - let err = self.0.error.take().unwrap_or_else(|| { - io::Error::new(io::ErrorKind::Other, "Internal error") - }); - Err(Either::Right(err)) - } else { - Ok(()) - } - } - - #[inline] - pub fn poll_next( - &self, - io: &mut T, - codec: &U, - cx: &mut Context<'_>, - ) -> Poll, Either>> - where - T: AsyncRead + AsyncWrite + Unpin, - U: Decoder, - { - let mut buf = self.0.get_read_buf(); - - loop { - let item = match codec.decode(&mut buf) { - Ok(Some(el)) => Poll::Ready(Ok(Some(el))), - Ok(None) => { - match crate::codec::poll_read_buf(Pin::new(&mut *io), cx, &mut buf) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(err)) => Poll::Ready(Err(Either::Right(err))), - Poll::Ready(Ok(n)) => { - if n == 0 { - Poll::Ready(Ok(None)) - } else { - continue; - } - } - } - } - Err(err) => { - self.set_io_error(None); - Poll::Ready(Err(Either::Left(err))) - } - }; - self.0.release_read_buf(buf); - return item; - } - } - - /// read data from io steram and update internal state - pub(super) fn read_io(&self, io: &mut T, cx: &mut Context<'_>) -> bool - where - T: AsyncRead + AsyncWrite + Unpin, - { - let inner = self.0.as_ref(); - let (hw, lw) = inner.pool.get().read_params().unpack(); - let mut buf = inner.get_read_buf(); - - // read data from socket - let mut updated = false; - loop { - // make sure we've got room - let remaining = buf.capacity() - buf.len(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - match crate::codec::poll_read_buf(Pin::new(&mut *io), cx, &mut buf) { - Poll::Pending => break, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("io stream is disconnected"); - inner.release_read_buf(buf); - self.set_io_error(None); - return false; - } else { - if buf.len() > hw { - log::trace!( - "buffer is too large {}, enable read back-pressure", - buf.len() - ); - inner.dispatch_task.wake(); - inner.read_buf.set(Some(buf)); - inner.read_task.register(cx.waker()); - inner.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL); - return true; - } - - updated = true; - } - } - Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - inner.release_read_buf(buf); - self.set_io_error(Some(err)); - return false; - } - } - } - - if updated { - inner.read_buf.set(Some(buf)); - self.insert_flags(Flags::RD_READY); - self.0.dispatch_task.wake(); - } else { - inner.release_read_buf(buf); - } - self.0.read_task.register(cx.waker()); - true - } - - /// Flush write buffer to underlying I/O stream. - pub(super) fn flush_io(&self, io: &mut T, cx: &mut Context<'_>) -> Poll - where - T: AsyncRead + AsyncWrite + Unpin, - { - let inner = self.0.as_ref(); - let mut buf = if let Some(buf) = inner.write_buf.take() { - buf - } else { - self.0.write_task.register(cx.waker()); - return Poll::Ready(true); - }; - let len = buf.len(); - - if len != 0 { - // log::trace!("flushing framed transport: {}", len); - - let mut written = 0; - while written < len { - match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { - Poll::Pending => break, - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!( - "Disconnected during flush, written {}", - written - ); - buf.clear(); - inner.release_write_buf(buf); - self.set_io_error(Some(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))); - return Poll::Ready(false); - } else { - written += n - } - } - Poll::Ready(Err(e)) => { - log::trace!("Error during flush: {}", e); - buf.clear(); - inner.release_write_buf(buf); - self.set_io_error(Some(e)); - return Poll::Ready(false); - } - } - } - // log::trace!("flushed {} bytes", written); - - // remove written data - if written == len { - buf.clear() - } else { - buf.advance(written); - } - } - - // if write buffer is smaller than high watermark value, turn off back-pressure - if buf.len() < self.0.pool.get().write_params_high() { - let mut flags = self.0.flags.get(); - if flags.contains(Flags::WR_BACKPRESSURE) { - flags.remove(Flags::WR_BACKPRESSURE); - self.0.flags.set(flags); - self.0.dispatch_task.wake(); - } - } else { - self.insert_flags(Flags::WR_BACKPRESSURE); - } - self.0.write_task.register(cx.waker()); - - // flush - let result = match Pin::new(&mut *io).poll_flush(cx) { - Poll::Ready(Ok(_)) => { - if buf.is_empty() { - Poll::Ready(true) - } else { - Poll::Pending - } - } - Poll::Pending => Poll::Pending, - Poll::Ready(Err(err)) => { - log::trace!("Error during flush: {}", err); - self.set_io_error(Some(err)); - Poll::Ready(false) - } - }; - inner.release_write_buf(buf); - result - } -} - -#[derive(Copy, Clone)] -pub struct Write<'a>(&'a IoStateInner); - -impl<'a> Write<'a> { - #[inline] - /// Check if write task is ready - pub fn is_ready(&self) -> bool { - !self.0.flags.get().contains(Flags::WR_BACKPRESSURE) - } - - #[inline] - /// Check if write buffer is full - pub fn is_full(&self) -> bool { - if let Some(buf) = self.0.read_buf.take() { - let hw = self.0.pool.get().write_params_high(); - let result = buf.len() >= hw; - self.0.write_buf.set(Some(buf)); - result - } else { - false - } - } - - #[inline] - /// Wait until write task flushes data to io stream - /// - /// Write task must be waken up separately. - pub fn enable_backpressure(&self, waker: Option<&Waker>) { - log::trace!("enable write back-pressure"); - self.0.insert_flags(Flags::WR_BACKPRESSURE); - if let Some(waker) = waker { - self.0.dispatch_task.register(waker); - } - } - - #[inline] - /// Wake dispatcher task - pub fn wake_dispatcher(&self) { - self.0.dispatch_task.wake(); - } - - /// Get mut access to write buffer - pub fn with_buf(&self, f: F) -> R - where - F: FnOnce(&mut BytesMut) -> R, - { - let mut buf = self.0.get_write_buf(); - if buf.is_empty() { - self.0.write_task.wake(); - } - - let result = f(&mut buf); - self.0.release_write_buf(buf); - result - } - - #[inline] - /// Write item to a buffer and wake up write task - /// - /// Returns write buffer state, false is returned if write buffer if full. - pub fn encode( - &self, - item: U::Item, - codec: &U, - ) -> Result::Error> - where - U: Encoder, - { - let flags = self.0.flags.get(); - - if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { - let mut buf = self.0.get_write_buf(); - let is_write_sleep = buf.is_empty(); - let (hw, lw) = self.0.pool.get().write_params().unpack(); - - // make sure we've got room - let remaining = buf.capacity() - buf.len(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - // encode item and wake write task - let result = codec.encode(item, &mut buf).map(|_| { - if is_write_sleep { - self.0.write_task.wake(); - } - buf.len() < hw - }); - self.0.write_buf.set(Some(buf)); - result - } else { - Ok(true) - } - } - - #[inline] - /// Write item to a buf and wake up io task - pub fn encode_result( - &self, - item: Result, E>, - codec: &U, - ) -> Result> - where - U: Encoder, - { - let flags = self.0.flags.get(); - - if !flags.intersects(Flags::IO_ERR | Flags::DSP_ERR) { - match item { - Ok(Some(item)) => { - let mut buf = self.0.get_write_buf(); - let is_write_sleep = buf.is_empty(); - let (hw, lw) = self.0.pool.get().write_params().unpack(); - - // make sure we've got room - let remaining = buf.capacity() - buf.len(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - // encode item - if let Err(err) = codec.encode(item, &mut buf) { - log::trace!("Encoder error: {:?}", err); - self.0.release_write_buf(buf); - self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR); - self.0.dispatch_task.wake(); - return Err(Either::Right(err)); - } else if is_write_sleep { - self.0.write_task.wake(); - } - let result = Ok(buf.len() < hw); - self.0.write_buf.set(Some(buf)); - result - } - Err(err) => { - self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR); - self.0.dispatch_task.wake(); - Err(Either::Left(err)) - } - _ => Ok(true), - } - } else { - Ok(true) - } - } -} - -#[derive(Copy, Clone)] -pub struct Read<'a>(&'a IoStateInner); - -impl<'a> Read<'a> { - #[inline] - /// Check if read buffer has new data - pub fn is_ready(&self) -> bool { - self.0.flags.get().contains(Flags::RD_READY) - } - - #[inline] - /// Check if read buffer is full - pub fn is_full(&self) -> bool { - if let Some(buf) = self.0.read_buf.take() { - let result = buf.len() >= self.0.pool.get().read_params_high(); - self.0.read_buf.set(Some(buf)); - result - } else { - false - } - } - - #[inline] - /// Pause read task - /// - /// Also register dispatch task - pub fn pause(&self, waker: &Waker) { - self.0.insert_flags(Flags::RD_PAUSED); - self.0.dispatch_task.register(waker); - } - - #[inline] - /// Wake read io task if it is paused - pub fn resume(&self) -> bool { - let flags = self.0.flags.get(); - if flags.contains(Flags::RD_PAUSED) { - self.0.remove_flags(Flags::RD_PAUSED); - self.0.read_task.wake(); - true - } else { - false - } - } - - #[inline] - /// Wake read task and instruct to read more data - /// - /// Only wakes if back-pressure is enabled on read task - /// otherwise read is already awake. - pub fn wake(&self, waker: &Waker) { - let mut flags = self.0.flags.get(); - flags.remove(Flags::RD_READY); - if flags.contains(Flags::RD_BUF_FULL) { - log::trace!("read back-pressure is enabled, wake io task"); - flags.remove(Flags::RD_BUF_FULL); - self.0.read_task.wake(); - } - self.0.flags.set(flags); - self.0.dispatch_task.register(waker); - } - - #[inline] - /// Attempts to decode a frame from the read buffer. - pub fn decode( - &self, - codec: &U, - ) -> Result::Item>, ::Error> - where - U: Decoder, - { - let mut buf = self.0.get_read_buf(); - let result = codec.decode(&mut buf); - self.0.release_read_buf(buf); - result - } - - /// Get mut access to read buffer - pub fn with_buf(&self, f: F) -> R - where - F: FnOnce(&mut BytesMut) -> R, - { - let mut buf = self.0.get_read_buf(); - let res = f(&mut buf); - self.0.release_read_buf(buf); - res - } -} - -/// OnDisconnect future resolves when socket get disconnected -#[must_use = "OnDisconnect do nothing unless polled"] -pub struct OnDisconnect { - token: usize, - inner: Rc, -} - -impl OnDisconnect { - fn new(inner: Rc, disconnected: bool) -> Self { - let token = inner.on_disconnect.borrow_mut().insert(if disconnected { - Some(LocalWaker::default()) - } else { - None - }); - Self { token, inner } - } - - pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { - let mut on_disconnect = self.inner.on_disconnect.borrow_mut(); - - let inner = unsafe { on_disconnect.get_unchecked_mut(self.token) }; - if inner.is_none() { - let waker = LocalWaker::default(); - waker.register(cx.waker()); - *inner = Some(waker); - } else if !inner.as_mut().unwrap().register(cx.waker()) { - return Poll::Ready(()); - } - Poll::Pending - } -} - -impl Clone for OnDisconnect { - fn clone(&self) -> Self { - let token = self.inner.on_disconnect.borrow_mut().insert(None); - OnDisconnect { - token, - inner: self.inner.clone(), - } - } -} - -impl Future for OnDisconnect { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.get_mut().poll_ready(cx) - } -} - -impl Drop for OnDisconnect { - fn drop(&mut self) { - self.inner.on_disconnect.borrow_mut().remove(self.token); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{codec::BytesCodec, testing::Io, util::lazy, util::Bytes}; - - const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n"; - const TEXT: &str = "GET /test HTTP/1\r\n\r\n"; - - #[crate::rt_test] - async fn test_utils() { - let (client, mut server) = Io::create(); - client.remote_buffer_cap(1024); - client.write(TEXT); - - let state = State::new(); - assert!(!state.read().is_full()); - assert!(!state.write().is_full()); - - let msg = state.next(&mut server, &BytesCodec).await.unwrap().unwrap(); - assert_eq!(msg, Bytes::from_static(BIN)); - - let res = - poll_fn(|cx| Poll::Ready(state.poll_next(&mut server, &BytesCodec, cx))) - .await; - assert!(res.is_pending()); - client.write(TEXT); - let res = - poll_fn(|cx| Poll::Ready(state.poll_next(&mut server, &BytesCodec, cx))) - .await; - if let Poll::Ready(msg) = res { - assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN)); - } - - client.read_error(io::Error::new(io::ErrorKind::Other, "err")); - let msg = state.next(&mut server, &BytesCodec).await; - assert!(msg.is_err()); - state.flags().contains(Flags::IO_ERR); - state.flags().contains(Flags::DSP_STOP); - state.remove_flags(Flags::IO_ERR | Flags::DSP_STOP); - - client.read_error(io::Error::new(io::ErrorKind::Other, "err")); - let res = - poll_fn(|cx| Poll::Ready(state.poll_next(&mut server, &BytesCodec, cx))) - .await; - if let Poll::Ready(msg) = res { - assert!(msg.is_err()); - state.flags().contains(Flags::IO_ERR); - state.flags().contains(Flags::DSP_STOP); - state.remove_flags(Flags::IO_ERR | Flags::DSP_STOP); - } - - state - .send(&mut server, &BytesCodec, Bytes::from_static(b"test")) - .await - .unwrap(); - let buf = client.read().await.unwrap(); - assert_eq!(buf, Bytes::from_static(b"test")); - - client.write_error(io::Error::new(io::ErrorKind::Other, "err")); - let res = state - .send(&mut server, &BytesCodec, Bytes::from_static(b"test")) - .await; - assert!(res.is_err()); - state.flags().contains(Flags::IO_ERR); - state.flags().contains(Flags::DSP_STOP); - state.remove_flags(Flags::IO_ERR | Flags::DSP_STOP); - - state.remove_flags(Flags::IO_ERR | Flags::DSP_STOP); - state.force_close(); - state.flags().contains(Flags::DSP_STOP); - state.flags().contains(Flags::IO_SHUTDOWN); - } - - #[crate::rt_test] - async fn test_on_disconnect() { - let state = State::new(); - let mut waiter = state.on_disconnect(); - assert_eq!( - lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, - Poll::Pending - ); - let mut waiter2 = waiter.clone(); - assert_eq!( - lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await, - Poll::Pending - ); - state.set_wr_shutdown_complete(); - assert_eq!(waiter.await, ()); - assert_eq!(waiter2.await, ()); - - let mut waiter = state.on_disconnect(); - assert_eq!( - lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, - Poll::Ready(()) - ); - - let state = State::new(); - let mut waiter = state.on_disconnect(); - assert_eq!( - lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, - Poll::Pending - ); - state.set_io_error(None); - assert_eq!(waiter.await, ()); - } -} diff --git a/ntex/src/framed/time.rs b/ntex/src/framed/time.rs deleted file mode 100644 index e1105546..00000000 --- a/ntex/src/framed/time.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::{cell::RefCell, collections::BTreeMap, rc::Rc, time::Instant}; - -use crate::framed::State; -use crate::time::{now, sleep, Millis}; -use crate::util::HashSet; - -pub struct Timer(Rc>); - -struct Inner { - resolution: Millis, - current: Option, - notifications: BTreeMap>, -} - -impl Inner { - fn new(resolution: Millis) -> Self { - Inner { - resolution, - current: None, - notifications: BTreeMap::default(), - } - } - - fn unregister(&mut self, expire: Instant, state: &State) { - if let Some(ref mut states) = self.notifications.get_mut(&expire) { - states.remove(state); - if states.is_empty() { - self.notifications.remove(&expire); - } - } - } -} - -impl Clone for Timer { - fn clone(&self) -> Self { - Timer(self.0.clone()) - } -} - -impl Default for Timer { - fn default() -> Self { - Timer::new(Millis::ONE_SEC) - } -} - -impl Timer { - /// Create new timer with resolution in milliseconds - pub fn new(resolution: Millis) -> Timer { - Timer(Rc::new(RefCell::new(Inner::new(resolution)))) - } - - pub fn register(&self, expire: Instant, previous: Instant, state: &State) { - { - let mut inner = self.0.borrow_mut(); - - inner.unregister(previous, state); - inner - .notifications - .entry(expire) - .or_insert_with(HashSet::default) - .insert(state.clone()); - } - - let _ = self.now(); - } - - pub fn unregister(&self, expire: Instant, state: &State) { - self.0.borrow_mut().unregister(expire, state); - } - - /// Get current time. This function has to be called from - /// future's poll method, otherwise it panics. - pub fn now(&self) -> Instant { - let cur = self.0.borrow().current; - if let Some(cur) = cur { - cur - } else { - let now_val = now(); - let inner = self.0.clone(); - let interval = { - let mut b = inner.borrow_mut(); - b.current = Some(now_val); - b.resolution - }; - - crate::rt::spawn(async move { - sleep(interval).await; - let empty = { - let mut i = inner.borrow_mut(); - let now = i.current.take().unwrap_or_else(now); - - // notify io dispatcher - while let Some(key) = i.notifications.keys().next() { - let key = *key; - if key <= now { - for st in i.notifications.remove(&key).unwrap() { - st.keepalive_timeout(); - } - } else { - break; - } - } - i.notifications.is_empty() - }; - - // extra tick - if !empty { - let _ = Timer(inner).now(); - } - }); - - now_val - } - } -} diff --git a/ntex/src/framed/write.rs b/ntex/src/framed/write.rs deleted file mode 100644 index 0baa5ec6..00000000 --- a/ntex/src/framed/write.rs +++ /dev/null @@ -1,166 +0,0 @@ -use std::task::{Context, Poll}; -use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc}; - -use crate::codec::{AsyncRead, AsyncWrite, ReadBuf}; -use crate::framed::State; -use crate::time::{sleep, Sleep}; - -#[derive(Debug)] -enum IoWriteState { - Processing, - Shutdown(Option, Shutdown), -} - -#[derive(Debug)] -enum Shutdown { - None, - Flushed, - Stopping, -} - -/// Write io task -pub struct WriteTask -where - T: AsyncRead + AsyncWrite + Unpin, -{ - st: IoWriteState, - io: Rc>, - state: State, -} - -impl WriteTask -where - T: AsyncRead + AsyncWrite + Unpin, -{ - /// Create new write io task - pub fn new(io: Rc>, state: State) -> Self { - Self { - io, - state, - st: IoWriteState::Processing, - } - } - - /// Shutdown io stream - pub fn shutdown(io: Rc>, state: State) -> Self { - let disconnect_timeout = state.get_disconnect_timeout(); - let st = IoWriteState::Shutdown(disconnect_timeout.map(sleep), Shutdown::None); - - Self { st, io, state } - } -} - -impl Future for WriteTask -where - T: AsyncRead + AsyncWrite + Unpin, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().get_mut(); - - // IO error occured - if this.state.is_io_err() { - log::trace!("write io is closed"); - return Poll::Ready(()); - } else if this.state.is_io_stop() { - self.state.wake_dispatcher(); - return Poll::Ready(()); - } - - match this.st { - IoWriteState::Processing => { - if this.state.is_io_shutdown() { - log::trace!("write task is instructed to shutdown"); - - let disconnect_timeout = this.state.get_disconnect_timeout(); - this.st = IoWriteState::Shutdown( - disconnect_timeout.map(sleep), - Shutdown::None, - ); - return self.poll(cx); - } - - // flush framed instance - match this.state.flush_io(&mut *this.io.borrow_mut(), cx) { - Poll::Pending | Poll::Ready(true) => Poll::Pending, - Poll::Ready(false) => Poll::Ready(()), - } - } - IoWriteState::Shutdown(ref mut delay, ref mut st) => { - // close WRITE side and wait for disconnect on read side. - // use disconnect timeout, otherwise it could hang forever. - loop { - match st { - Shutdown::None => { - // flush write buffer - let result = - this.state.flush_io(&mut *this.io.borrow_mut(), cx); - match result { - Poll::Ready(true) => { - *st = Shutdown::Flushed; - continue; - } - Poll::Ready(false) => { - this.state.set_wr_shutdown_complete(); - log::trace!( - "write task is closed with err during flush" - ); - return Poll::Ready(()); - } - _ => (), - } - } - Shutdown::Flushed => { - // shutdown WRITE side - match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx) - { - Poll::Ready(Ok(_)) => { - *st = Shutdown::Stopping; - continue; - } - Poll::Ready(Err(_)) => { - this.state.set_wr_shutdown_complete(); - log::trace!( - "write task is closed with err during shutdown" - ); - return Poll::Ready(()); - } - _ => (), - } - } - Shutdown::Stopping => { - // read until 0 or err - let mut buf = [0u8; 512]; - let mut io = this.io.borrow_mut(); - loop { - let mut read_buf = ReadBuf::new(&mut buf); - match Pin::new(&mut *io).poll_read(cx, &mut read_buf) { - Poll::Ready(Err(_)) | Poll::Ready(Ok(_)) - if read_buf.filled().is_empty() => - { - this.state.set_wr_shutdown_complete(); - log::trace!("write task is stopped"); - return Poll::Ready(()); - } - Poll::Pending => break, - _ => (), - } - } - } - } - - // disconnect timeout - if let Some(ref delay) = delay { - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } - } - this.state.set_wr_shutdown_complete(); - log::trace!("write task is stopped after delay"); - return Poll::Ready(()); - } - } - } - } -} diff --git a/ntex/src/http/builder.rs b/ntex/src/http/builder.rs index 23b9d84f..6bb29d6c 100644 --- a/ntex/src/http/builder.rs +++ b/ntex/src/http/builder.rs @@ -1,11 +1,11 @@ use std::{cell::RefCell, error::Error, fmt, marker::PhantomData, rc::Rc}; -use crate::framed::State; use crate::http::body::MessageBody; use crate::http::config::{KeepAlive, OnRequest, ServiceConfig}; use crate::http::error::ResponseError; use crate::http::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler}; -use crate::http::h2::H2Service; +use crate::io::{Filter, Io, IoRef}; +// use crate::http::h2::H2Service; use crate::http::helpers::{Data, DataFactory}; use crate::http::request::Request; use crate::http::response::Response; @@ -18,7 +18,7 @@ use crate::util::PoolId; /// /// This type can be used to construct an instance of `http service` through a /// builder-like pattern. -pub struct HttpServiceBuilder> { +pub struct HttpServiceBuilder> { keep_alive: KeepAlive, client_timeout: Millis, client_disconnect: Seconds, @@ -26,12 +26,11 @@ pub struct HttpServiceBuilder> { pool: PoolId, expect: X, upgrade: Option, - on_connect: Option Box>>, - on_request: Option>, - _t: PhantomData<(T, S)>, + on_request: Option, + _t: PhantomData<(F, S)>, } -impl HttpServiceBuilder> { +impl HttpServiceBuilder> { /// Create instance of `ServiceConfigBuilder` pub fn new() -> Self { HttpServiceBuilder { @@ -42,15 +41,15 @@ impl HttpServiceBuilder> { pool: PoolId::P1, expect: ExpectHandler, upgrade: None, - on_connect: None, on_request: None, _t: PhantomData, } } } -impl HttpServiceBuilder +impl HttpServiceBuilder where + F: Filter + 'static, S: ServiceFactory, S::Error: ResponseError + 'static, S::InitError: fmt::Debug, @@ -61,7 +60,7 @@ where X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory, + U: ServiceFactory, Codec), Response = ()>, U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, @@ -122,29 +121,14 @@ where self } - #[doc(hidden)] - #[deprecated(since = "0.4.12", note = "Use memory pool config")] - #[inline] - /// Set read/write buffer params - /// - /// By default read buffer is 8kb, write buffer is 8kb - pub fn buffer_params( - self, - _max_read_buf_size: u16, - _max_write_buf_size: u16, - _min_buf_size: u16, - ) -> Self { - self - } - /// Provide service for `EXPECT: 100-Continue` support. /// /// Service get called with request that contains `EXPECT` header. /// Service must return request in case of success, in that case /// request will be forwarded to main service. - pub fn expect(self, expect: F) -> HttpServiceBuilder + pub fn expect(self, expect: XF) -> HttpServiceBuilder where - F: IntoServiceFactory, + XF: IntoServiceFactory, X1: ServiceFactory, X1::Error: ResponseError + 'static, X1::InitError: fmt::Debug, @@ -159,7 +143,6 @@ where pool: self.pool, expect: expect.into_factory(), upgrade: self.upgrade, - on_connect: self.on_connect, on_request: self.on_request, _t: PhantomData, } @@ -169,12 +152,12 @@ where /// /// If service is provided then normal requests handling get halted /// and this service get called with original request and framed object. - pub fn upgrade(self, upgrade: F) -> HttpServiceBuilder + pub fn upgrade(self, upgrade: UF) -> HttpServiceBuilder where - F: IntoServiceFactory, + UF: IntoServiceFactory, U1: ServiceFactory< Config = (), - Request = (Request, T, State, Codec), + Request = (Request, Io, Codec), Response = (), >, U1::Error: fmt::Display + Error + 'static, @@ -190,46 +173,29 @@ where pool: self.pool, expect: self.expect, upgrade: Some(upgrade.into_factory()), - on_connect: self.on_connect, on_request: self.on_request, _t: PhantomData, } } - /// Set on-connect callback. - /// - /// It get called once per connection and result of the call - /// get stored to the request's extensions. - pub fn on_connect(mut self, f: F) -> Self - where - F: Fn(&T) -> I + 'static, - I: Clone + 'static, - { - self.on_connect = Some(Rc::new(move |io| Box::new(Data(f(io))))); - self - } - /// Set req request callback. /// /// It get called once per request. - pub fn on_request(mut self, f: F) -> Self + pub fn on_request(mut self, f: FR) -> Self where - F: IntoService, - Filter: Service< - Request = (Request, Rc>), - Response = Request, - Error = Response, - > + 'static, + FR: IntoService, + R: Service + + 'static, { self.on_request = Some(boxed::service(f.into_service())); self } /// Finish service configuration and create *http service* for HTTP/1 protocol. - pub fn h1(self, service: F) -> H1Service + pub fn h1(self, service: SF) -> H1Service where B: MessageBody, - F: IntoServiceFactory, + SF: IntoServiceFactory, S::Error: ResponseError, S::InitError: fmt::Debug, S::Response: Into>, @@ -245,15 +211,15 @@ where H1Service::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) - .on_connect(self.on_connect) .on_request(self.on_request) } + // pub fn h2(self, service: F) -> H2Service /// Finish service configuration and create *http service* for HTTP/2 protocol. - pub fn h2(self, service: F) -> H2Service + pub fn h2(self, service: SF) -> H1Service where B: MessageBody + 'static, - F: IntoServiceFactory, + SF: IntoServiceFactory, S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, @@ -266,14 +232,19 @@ where self.handshake_timeout, self.pool, ); - H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) + + // H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect) + H1Service::with_config(cfg, service.into_factory()) + .expect(self.expect) + .upgrade(self.upgrade) + .on_request(self.on_request) } /// Finish service configuration and create `HttpService` instance. - pub fn finish(self, service: F) -> HttpService + pub fn finish(self, service: SF) -> HttpService where B: MessageBody + 'static, - F: IntoServiceFactory, + SF: IntoServiceFactory, S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, @@ -290,7 +261,6 @@ where HttpService::with_config(cfg, service.into_factory()) .expect(self.expect) .upgrade(self.upgrade) - .on_connect(self.on_connect) .on_request(self.on_request) } } diff --git a/ntex/src/http/client/builder.rs b/ntex/src/http/client/builder.rs index b1b47f82..0c9b480e 100644 --- a/ntex/src/http/client/builder.rs +++ b/ntex/src/http/client/builder.rs @@ -42,10 +42,8 @@ impl ClientBuilder { /// Use custom connector service. pub fn connector(mut self, connector: T) -> Self where - T: Service + 'static, - T::Response: Connection, - ::Future: 'static, - T::Future: 'static, + T: Service + + 'static, { self.config.connector = Box::new(ConnectorWrapper(connector)); self diff --git a/ntex/src/http/client/connect.rs b/ntex/src/http/client/connect.rs index b642fb4e..b206f41b 100644 --- a/ntex/src/http/client/connect.rs +++ b/ntex/src/http/client/connect.rs @@ -1,15 +1,23 @@ use std::{fmt, future::Future, io, net, pin::Pin, task::Context, task::Poll}; -use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; use crate::http::body::Body; use crate::http::h1::ClientCodec; use crate::http::{RequestHeadType, ResponseHead}; +use crate::io::IoBoxed; use crate::Service; use super::error::{ConnectError, SendRequestError}; use super::response::ClientResponse; use super::{Connect as ClientConnect, Connection}; +pub(crate) type TunnelFuture = Pin< + Box< + dyn Future< + Output = Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError>, + >, + >, +>; + pub(super) struct ConnectorWrapper(pub(crate) T); pub(super) trait Connect { @@ -25,25 +33,12 @@ pub(super) trait Connect { &self, head: RequestHeadType, addr: Option, - ) -> Pin< - Box< - dyn Future< - Output = Result< - (ResponseHead, Framed), - SendRequestError, - >, - >, - >, - >; + ) -> TunnelFuture; } impl Connect for ConnectorWrapper where - T: Service, - T::Response: Connection, - ::Io: 'static, - ::Future: 'static, - ::TunnelFuture: 'static, + T: Service, T::Future: 'static, { fn send_request( @@ -73,16 +68,7 @@ where &self, head: RequestHeadType, addr: Option, - ) -> Pin< - Box< - dyn Future< - Output = Result< - (ResponseHead, Framed), - SendRequestError, - >, - >, - >, - > { + ) -> TunnelFuture { // connect to the host let fut = self.0.call(ClientConnect { uri: head.as_ref().uri.clone(), @@ -93,69 +79,7 @@ where let connection = fut.await?; // send request - let (head, framed) = connection.open_tunnel(head).await?; - - let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); - Ok((head, framed)) + connection.open_tunnel(head).await }) } } - -trait AsyncSocket { - fn as_read(&self) -> &(dyn AsyncRead + Unpin); - fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin); - fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin); -} - -struct Socket(T); - -impl AsyncSocket for Socket { - fn as_read(&self) -> &(dyn AsyncRead + Unpin) { - &self.0 - } - fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) { - &mut self.0 - } - fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) { - &mut self.0 - } -} - -pub struct BoxedSocket(Box); - -impl fmt::Debug for BoxedSocket { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "BoxedSocket") - } -} - -impl AsyncRead for BoxedSocket { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) - } -} - -impl AsyncWrite for BoxedSocket { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(self.get_mut().0.as_write()).poll_flush(cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx) - } -} diff --git a/ntex/src/http/client/connection.rs b/ntex/src/http/client/connection.rs index f7910d72..bb3b8ee1 100644 --- a/ntex/src/http/client/connection.rs +++ b/ntex/src/http/client/connection.rs @@ -8,79 +8,43 @@ use crate::http::h1::ClientCodec; use crate::http::message::{RequestHeadType, ResponseHead}; use crate::http::payload::Payload; use crate::http::Protocol; +use crate::io::IoBoxed; use crate::util::{Bytes, Either, Ready}; use super::error::SendRequestError; use super::pool::Acquired; use super::{h1proto, h2proto}; -pub(super) enum ConnectionType { - H1(Io), +pub(super) enum ConnectionType { + H1(IoBoxed), H2(SendRequest), } -pub trait Connection { - type Io: AsyncRead + AsyncWrite + Unpin; - type Future: Future>; - - fn protocol(&self) -> Protocol; - - /// Send request and body - fn send_request>( - self, - head: H, - body: B, - ) -> Self::Future; - - type TunnelFuture: Future< - Output = Result<(ResponseHead, Framed), SendRequestError>, - >; - - /// Send request, returns Response and Framed - fn open_tunnel>(self, head: H) -> Self::TunnelFuture; -} - -pub(super) trait ConnectionLifetime: - AsyncRead + AsyncWrite + Unpin + 'static -{ - /// Close connection - fn close(&mut self); - - /// Release connection to the connection pool - fn release(&mut self); -} - #[doc(hidden)] /// HTTP client connection -pub(super) struct IoConnection { - io: Option>, +pub struct Connection { + io: Option, created: time::Instant, - pool: Option>, + pool: Option, } -impl fmt::Debug for IoConnection -where - T: fmt::Debug, -{ +impl fmt::Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.io { - Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io), + Some(ConnectionType::H1(_)) => write!(f, "H1Connection"), Some(ConnectionType::H2(_)) => write!(f, "H2Connection"), None => write!(f, "Connection(Empty)"), } } } -impl IoConnection -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Connection { pub(super) fn new( - io: ConnectionType, + io: ConnectionType, created: time::Instant, - pool: Option>, + pool: Option, ) -> Self { - IoConnection { + Self { pool, created, io: Some(io), @@ -97,20 +61,11 @@ where } } - pub(super) fn into_inner(self) -> (ConnectionType, time::Instant) { + pub(super) fn into_inner(self) -> (ConnectionType, time::Instant) { (self.io.unwrap(), self.created) } -} -impl Connection for IoConnection -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - type Io = T; - type Future = - Pin>>>; - - fn protocol(&self) -> Protocol { + pub fn protocol(&self) -> Protocol { match self.io { Some(ConnectionType::H1(_)) => Protocol::Http1, Some(ConnectionType::H2(_)) => Protocol::Http2, @@ -118,58 +73,42 @@ where } } - fn send_request>( + pub(super) async fn send_request< + B: MessageBody + 'static, + H: Into, + >( mut self, head: H, body: B, - ) -> Self::Future { + ) -> Result<(ResponseHead, Payload), SendRequestError> { match self.io.take().unwrap() { - ConnectionType::H1(io) => Box::pin(h1proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), - ConnectionType::H2(io) => Box::pin(h2proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), + ConnectionType::H1(io) => { + h1proto::send_request(io, head.into(), body, self.created, self.pool) + .await + } + ConnectionType::H2(io) => { + h2proto::send_request(io, head.into(), body, self.created, self.pool) + .await + } } } - type TunnelFuture = Either< - Pin< - Box< - dyn Future< - Output = Result< - (ResponseHead, Framed), - SendRequestError, - >, - >, - >, - >, - Ready<(ResponseHead, Framed), SendRequestError>, - >; - /// Send request, returns Response and Framed - fn open_tunnel>(mut self, head: H) -> Self::TunnelFuture { + pub(super) async fn open_tunnel>( + mut self, + head: H, + ) -> Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError> { match self.io.take().unwrap() { - ConnectionType::H1(io) => { - Either::Left(Box::pin(h1proto::open_tunnel(io, head.into()))) - } + ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await, ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { - pool.release(IoConnection::new( + pool.release(Connection::new( ConnectionType::H2(io), self.created, None, )); } - Either::Right(Ready::Err(SendRequestError::TunnelNotSupported)) + Err(SendRequestError::TunnelNotSupported) } } } diff --git a/ntex/src/http/client/connector.rs b/ntex/src/http/client/connector.rs index eec59adb..2e3e05c0 100644 --- a/ntex/src/http/client/connector.rs +++ b/ntex/src/http/client/connector.rs @@ -1,8 +1,8 @@ use std::{rc::Rc, task::Context, task::Poll, time::Duration}; -use crate::codec::{AsyncRead, AsyncWrite}; use crate::connect::{Connect as TcpConnect, Connector as TcpConnector}; use crate::http::{Protocol, Uri}; +use crate::io::{Filter, Io, IoBoxed}; use crate::service::{apply_fn, boxed, Service}; use crate::time::{Millis, Seconds}; use crate::util::timeout::{TimeoutError, TimeoutService}; @@ -16,13 +16,12 @@ use super::Connect; #[cfg(feature = "openssl")] use crate::connect::openssl::SslConnector as OpensslConnector; -#[cfg(feature = "rustls")] -use crate::connect::rustls::ClientConfig; -#[cfg(feature = "rustls")] -use std::sync::Arc; +//#[cfg(feature = "rustls")] +//use crate::connect::rustls::ClientConfig; +//#[cfg(feature = "rustls")] +//use std::sync::Arc; -type BoxedConnector = - boxed::BoxService, (Box, Protocol), ConnectError>; +type BoxedConnector = boxed::BoxService, IoBoxed, ConnectError>; /// Manages http client network connectivity. /// @@ -47,9 +46,6 @@ pub struct Connector { ssl_connector: Option, } -trait Io: AsyncRead + AsyncWrite + Unpin {} -impl Io for T {} - impl Default for Connector { fn default() -> Self { Connector::new() @@ -61,7 +57,7 @@ impl Connector { let conn = Connector { connector: boxed::service( TcpConnector::new() - .map(|io| (Box::new(io) as Box, Protocol::Http1)) + .map(|io| io.into_boxed()) .map_err(ConnectError::from), ), ssl_connector: None, @@ -82,28 +78,28 @@ impl Connector { .map_err(|e| error!("Cannot set ALPN protocol: {:?}", e)); conn.openssl(ssl.build()) } - #[cfg(all(not(feature = "openssl"), feature = "rustls"))] - { - use rust_tls::{OwnedTrustAnchor, RootCertStore}; + // #[cfg(all(not(feature = "openssl"), feature = "rustls"))] + // { + // use rust_tls::{OwnedTrustAnchor, RootCertStore}; - let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - let mut cert_store = RootCertStore::empty(); - cert_store.add_server_trust_anchors( - webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }), - ); - let mut config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(cert_store) - .with_no_client_auth(); - config.alpn_protocols = protos; - conn.rustls(Arc::new(config)) - } + // let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + // let mut cert_store = RootCertStore::empty(); + // cert_store.add_server_trust_anchors( + // webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + // OwnedTrustAnchor::from_subject_spki_name_constraints( + // ta.subject, + // ta.spki, + // ta.name_constraints, + // ) + // }), + // ); + // let mut config = ClientConfig::builder() + // .with_safe_defaults() + // .with_root_certificates(cert_store) + // .with_no_client_auth(); + // config.alpn_protocols = protos; + // conn.rustls(Arc::new(config)) + // } #[cfg(not(any(feature = "openssl", feature = "rustls")))] { conn @@ -126,41 +122,29 @@ impl Connector { pub fn openssl(self, connector: OpensslConnector) -> Self { use crate::connect::openssl::OpensslConnector; - const H2: &[u8] = b"h2"; - self.secure_connector(OpensslConnector::new(connector).map(|sock| { - let h2 = sock - .ssl() - .selected_alpn_protocol() - .map(|protos| protos.windows(2).any(|w| w == H2)) - .unwrap_or(false); - if h2 { - (sock, Protocol::Http2) - } else { - (sock, Protocol::Http1) - } - })) + self.secure_connector(OpensslConnector::new(connector)) } - #[cfg(feature = "rustls")] - /// Use rustls connector for secured connections. - pub fn rustls(self, connector: Arc) -> Self { - use crate::connect::rustls::RustlsConnector; + // #[cfg(feature = "rustls")] + // /// Use rustls connector for secured connections. + // pub fn rustls(self, connector: Arc) -> Self { + // use crate::connect::rustls::RustlsConnector; - const H2: &[u8] = b"h2"; - self.secure_connector(RustlsConnector::new(connector).map(|sock| { - let h2 = sock - .get_ref() - .1 - .alpn_protocol() - .map(|protos| protos.windows(2).any(|w| w == H2)) - .unwrap_or(false); - if h2 { - (Box::new(sock) as Box, Protocol::Http2) - } else { - (Box::new(sock) as Box, Protocol::Http1) - } - })) - } + // const H2: &[u8] = b"h2"; + // self.secure_connector(RustlsConnector::new(connector).map(|sock| { + // let h2 = sock + // .get_ref() + // .1 + // .alpn_protocol() + // .map(|protos| protos.windows(2).any(|w| w == H2)) + // .unwrap_or(false); + // if h2 { + // (Box::new(sock) as Box, Protocol::Http2) + // } else { + // (Box::new(sock) as Box, Protocol::Http1) + // } + // })) + // } /// Set total number of simultaneous connections per type of scheme. /// @@ -206,36 +190,36 @@ impl Connector { } /// Use custom connector to open un-secured connections. - pub fn connector(mut self, connector: T) -> Self + pub fn connector(mut self, connector: T) -> Self where - U: AsyncRead + AsyncWrite + Unpin + 'static, T: Service< Request = TcpConnect, - Response = (U, Protocol), + Response = Io, Error = crate::connect::ConnectError, > + 'static, + F: Filter, { self.connector = boxed::service( connector - .map(|(io, proto)| (Box::new(io) as Box, proto)) + .map(|io| io.into_boxed()) .map_err(ConnectError::from), ); self } /// Use custom connector to open secure connections. - pub fn secure_connector(mut self, connector: T) -> Self + pub fn secure_connector(mut self, connector: T) -> Self where - U: AsyncRead + AsyncWrite + Unpin + 'static, T: Service< Request = TcpConnect, - Response = (U, Protocol), + Response = Io, Error = crate::connect::ConnectError, > + 'static, + F: Filter, { self.ssl_connector = Some(boxed::service( connector - .map(|(io, proto)| (Box::new(io) as Box, proto)) + .map(|io| io.into_boxed()) .map_err(ConnectError::from), )); self @@ -246,12 +230,13 @@ impl Connector { /// its combinator chain. pub fn finish( self, - ) -> impl Service - + Clone { - let tcp_service = connector(self.connector, self.timeout); + ) -> impl Service + Clone + { + let tcp_service = + connector(self.connector, self.timeout, self.disconnect_timeout); let ssl_pool = if let Some(ssl_connector) = self.ssl_connector { - let srv = connector(ssl_connector, self.timeout); + let srv = connector(ssl_connector, self.timeout, self.disconnect_timeout); Some(ConnectionPool::new( srv, self.conn_lifetime, @@ -279,9 +264,10 @@ impl Connector { fn connector( connector: BoxedConnector, timeout: Millis, + disconnect_timeout: Millis, ) -> impl Service< Request = Connect, - Response = (Box, Protocol), + Response = IoBoxed, Error = ConnectError, Future = impl Unpin, > + Unpin { @@ -290,6 +276,10 @@ fn connector( apply_fn(connector, |msg: Connect, srv| { srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) }) + .map(move |io: IoBoxed| { + io.set_disconnect_timeout(disconnect_timeout); + io + }) .map_err(ConnectError::from), ) .map_err(|e| match e { @@ -298,28 +288,25 @@ fn connector( }) } -type Pool = ConnectionPool>; - struct InnerConnector { - tcp_pool: Pool, - ssl_pool: Option>, + tcp_pool: ConnectionPool, + ssl_pool: Option>, } impl Service for InnerConnector where - T: Service< - Request = Connect, - Response = (Box, Protocol), - Error = ConnectError, - > + Unpin + T: Service + + Unpin + 'static, T::Future: Unpin, { type Request = Connect; - type Response = as Service>::Response; + type Response = as Service>::Response; type Error = ConnectError; - type Future = - Either< as Service>::Future, Ready>; + type Future = Either< + as Service>::Future, + Ready, + >; #[inline] fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { diff --git a/ntex/src/http/client/h1proto.rs b/ntex/src/http/client/h1proto.rs index a6b090e1..dceb0679 100644 --- a/ntex/src/http/client/h1proto.rs +++ b/ntex/src/http/client/h1proto.rs @@ -1,28 +1,27 @@ -use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time}; +use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time::Instant}; -use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf}; use crate::http::body::{BodySize, MessageBody}; use crate::http::error::PayloadError; use crate::http::h1; use crate::http::header::{HeaderMap, HeaderValue, HOST}; use crate::http::message::{RequestHeadType, ResponseHead}; use crate::http::payload::{Payload, PayloadStream}; +use crate::io::IoBoxed; use crate::util::{next, poll_fn, send, BufMut, Bytes, BytesMut}; use crate::{Sink, Stream}; -use super::connection::{ConnectionLifetime, ConnectionType, IoConnection}; +use super::connection::{Connection, ConnectionType}; use super::error::{ConnectError, SendRequestError}; use super::pool::Acquired; -pub(super) async fn send_request( - io: T, +pub(super) async fn send_request( + io: IoBoxed, mut head: RequestHeadType, body: B, - created: time::Instant, - pool: Option>, + created: Instant, + pool: Option, ) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { // set request host header @@ -52,209 +51,130 @@ where } } - let io = H1Connection { - created, - pool, - io: Some(io), - }; + // let io = H1Connection { + // created, + // pool, + // io: Some(io), + // }; - // create Framed and send request - let mut framed = Framed::new(io, h1::ClientCodec::default()); - send(&mut framed, (head, body.size()).into()).await?; + // send request + let codec = h1::ClientCodec::default(); + io.send((head, body.size()).into(), &codec).await?; // send request body match body.size() { BodySize::None | BodySize::Empty | BodySize::Sized(0) => (), - _ => send_body(body, &mut framed).await?, + _ => { + send_body(body, &io, &codec).await?; + } }; // read response and init read body - let head = if let Some(result) = next(&mut framed).await { - result.map_err(SendRequestError::from)? + let head = if let Some(result) = io.next(&codec).await? { + result } else { return Err(SendRequestError::from(ConnectError::Disconnected)); }; - match framed.get_codec().message_type() { + match codec.message_type() { h1::MessageType::None => { - let force_close = !framed.get_codec().keepalive(); - release_connection(framed, force_close); + let force_close = !codec.keepalive(); + release_connection(io, force_close, created, pool); Ok((head, Payload::None)) } _ => { - let pl: PayloadStream = Box::pin(PlStream::new(framed)); + let pl: PayloadStream = Box::pin(PlStream::new(io, codec, created, pool)); Ok((head, pl.into())) } } } -pub(super) async fn open_tunnel( - io: T, +pub(super) async fn open_tunnel( + io: IoBoxed, head: RequestHeadType, -) -> Result<(ResponseHead, Framed), SendRequestError> -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +) -> Result<(ResponseHead, IoBoxed, h1::ClientCodec), SendRequestError> { // create Framed and send request - let mut framed = Framed::new(io, h1::ClientCodec::default()); - send(&mut framed, (head, BodySize::None).into()).await?; + let codec = h1::ClientCodec::default(); + io.send((head, BodySize::None).into(), &codec).await?; // read response - if let Some(result) = next(&mut framed).await { - let head = result.map_err(SendRequestError::from)?; - Ok((head, framed)) + if let Some(head) = io.next(&codec).await? { + Ok((head, io, codec)) } else { Err(SendRequestError::from(ConnectError::Disconnected)) } } /// send request body to the peer -pub(super) async fn send_body( +pub(super) async fn send_body( mut body: B, - framed: &mut Framed, + io: &IoBoxed, + codec: &h1::ClientCodec, ) -> Result<(), SendRequestError> where - I: ConnectionLifetime, B: MessageBody, { - let mut eof = false; - while !eof { - while !eof && !framed.is_write_buf_full() { - match poll_fn(|cx| body.poll_next_chunk(cx)).await { - Some(result) => { - framed.write(h1::Message::Chunk(Some(result?)))?; - } - None => { - eof = true; - framed.write(h1::Message::Chunk(None))?; + let wrt = io.write(); + + loop { + match poll_fn(|cx| body.poll_next_chunk(cx)).await { + Some(result) => { + if !wrt.encode(h1::Message::Chunk(Some(result?)), codec)? { + wrt.flush(false).await?; } } - } - - if !framed.is_write_buf_empty() { - poll_fn(|cx| match framed.flush(cx) { - Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), - Poll::Ready(Err(err)) => Poll::Ready(Err(err)), - Poll::Pending => { - if !framed.is_write_buf_full() { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - }) - .await?; + None => { + wrt.encode(h1::Message::Chunk(None), codec)?; + break; + } } } - - poll_fn(|cx| Pin::new(&mut *framed).poll_flush(cx)).await?; + wrt.flush(true).await?; Ok(()) } -#[doc(hidden)] -/// HTTP client connection -pub(super) struct H1Connection { - io: Option, - created: time::Instant, - pool: Option>, +pub(super) struct PlStream { + io: Option, + codec: h1::ClientPayloadCodec, + created: Instant, + pool: Option, } -impl ConnectionLifetime for H1Connection -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - /// Close connection - fn close(&mut self) { - if let Some(mut pool) = self.pool.take() { - if let Some(io) = self.io.take() { - pool.close(IoConnection::new( - ConnectionType::H1(io), - self.created, - None, - )); - } - } - } - - /// Release this connection to the connection pool - fn release(&mut self) { - if let Some(mut pool) = self.pool.take() { - if let Some(io) = self.io.take() { - pool.release(IoConnection::new( - ConnectionType::H1(io), - self.created, - None, - )); - } - } - } -} - -impl AsyncRead for H1Connection { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) - } -} - -impl AsyncWrite for H1Connection { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(self.io.as_mut().unwrap()).poll_flush(cx) - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) - } -} - -pub(super) struct PlStream { - framed: Option>, -} - -impl PlStream { - fn new(framed: Framed) -> Self { +impl PlStream { + fn new( + io: IoBoxed, + codec: h1::ClientCodec, + created: Instant, + pool: Option, + ) -> Self { PlStream { - framed: Some(framed.map_codec(|codec| codec.into_payload_codec())), + io: Some(io), + codec: codec.into_payload_codec(), + created, + pool, } } } -impl Stream for PlStream { +impl Stream for PlStream { type Item = Result; fn poll_next( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let this = self.get_mut(); + let mut this = self.as_mut(); - match this.framed.as_mut().unwrap().next_item(cx)? { + match this.io.as_ref().unwrap().poll_next(&this.codec, cx)? { Poll::Pending => Poll::Pending, Poll::Ready(Some(chunk)) => { if let Some(chunk) = chunk { Poll::Ready(Some(Ok(chunk))) } else { - let framed = this.framed.take().unwrap(); - let force_close = !framed.get_codec().keepalive(); - release_connection(framed, force_close); + let io = this.io.take().unwrap(); + let force_close = !this.codec.keepalive(); + release_connection(io, force_close, this.created, this.pool.take()); Poll::Ready(None) } } @@ -263,14 +183,17 @@ impl Stream for PlStream { } } -fn release_connection(framed: Framed, force_close: bool) -where - T: ConnectionLifetime, -{ - let mut parts = framed.into_parts(); - if !force_close && parts.read_buf.is_empty() && parts.write_buf.is_empty() { - parts.io.release() - } else { - parts.io.close() +fn release_connection( + io: IoBoxed, + force_close: bool, + created: Instant, + mut pool: Option, +) { + if force_close || io.is_closed() || io.read().with_buf(|buf| !buf.is_empty()) { + if let Some(mut pool) = pool.take() { + pool.close(Connection::new(ConnectionType::H1(io), created, None)); + } + } else if let Some(mut pool) = pool.take() { + pool.release(Connection::new(ConnectionType::H1(io), created, None)); } } diff --git a/ntex/src/http/client/h2proto.rs b/ntex/src/http/client/h2proto.rs index 7fb575ec..f50a523a 100644 --- a/ntex/src/http/client/h2proto.rs +++ b/ntex/src/http/client/h2proto.rs @@ -11,19 +11,18 @@ use crate::http::message::{RequestHeadType, ResponseHead}; use crate::http::payload::Payload; use crate::util::{poll_fn, Bytes}; -use super::connection::{ConnectionType, IoConnection}; +use super::connection::{Connection, ConnectionType}; use super::error::SendRequestError; use super::pool::Acquired; -pub(super) async fn send_request( +pub(super) async fn send_request( mut io: SendRequest, head: RequestHeadType, body: B, created: time::Instant, - pool: Option>, + pool: Option, ) -> Result<(ResponseHead, Payload), SendRequestError> where - T: AsyncRead + AsyncWrite + Unpin + 'static, B: MessageBody, { trace!("Sending client request: {:?} {:?}", head, body.size()); @@ -161,17 +160,17 @@ async fn send_body( } // release SendRequest object -fn release( +fn release( io: SendRequest, - pool: Option>, + pool: Option, created: time::Instant, close: bool, ) { if let Some(mut pool) = pool { if close { - pool.close(IoConnection::new(ConnectionType::H2(io), created, None)); + pool.close(Connection::new(ConnectionType::H2(io), created, None)); } else { - pool.release(IoConnection::new(ConnectionType::H2(io), created, None)); + pool.release(Connection::new(ConnectionType::H2(io), created, None)); } } } diff --git a/ntex/src/http/client/mod.rs b/ntex/src/http/client/mod.rs index c01c5e3d..50c8c9b2 100644 --- a/ntex/src/http/client/mod.rs +++ b/ntex/src/http/client/mod.rs @@ -34,7 +34,6 @@ mod test; pub mod ws; pub use self::builder::ClientBuilder; -pub use self::connect::BoxedSocket; pub use self::connection::Connection; pub use self::connector::Connector; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; @@ -47,7 +46,7 @@ use crate::http::error::HttpError; use crate::http::{HeaderMap, Method, RequestHead, Uri}; use crate::time::Millis; -use self::connect::{Connect as InnerConnect, ConnectorWrapper}; +use self::connect::{Connect as HttpConnect, ConnectorWrapper}; #[derive(Clone)] pub struct Connect { @@ -76,7 +75,7 @@ pub struct Connect { pub struct Client(Rc); pub(self) struct ClientConfig { - pub(self) connector: Box, + pub(self) connector: Box, pub(self) headers: HeaderMap, pub(self) timeout: Millis, } diff --git a/ntex/src/http/client/pool.rs b/ntex/src/http/client/pool.rs index 1526d409..0c9b3842 100644 --- a/ntex/src/http/client/pool.rs +++ b/ntex/src/http/client/pool.rs @@ -2,19 +2,20 @@ use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc}; -use h2::client::{Builder, Connection, SendRequest}; +use h2::client::{Builder, Connection as H2Connection, SendRequest}; use http::uri::Authority; use crate::channel::pool; use crate::codec::{AsyncRead, AsyncWrite, ReadBuf}; use crate::http::Protocol; +use crate::io::IoBoxed; use crate::rt::spawn; use crate::service::Service; use crate::task::LocalWaker; use crate::time::{now, sleep, Millis, Sleep}; use crate::util::{poll_fn, Bytes, HashMap}; -use super::connection::{ConnectionType, IoConnection}; +use super::connection::{Connection, ConnectionType}; use super::error::ConnectError; use super::Connect; @@ -29,16 +30,15 @@ impl From for Key { } } -type Waiter = pool::Sender, ConnectError>>; -type WaiterReceiver = pool::Receiver, ConnectError>>; +type Waiter = pool::Sender>; +type WaiterReceiver = pool::Receiver>; /// Connections pool -pub(super) struct ConnectionPool(Rc, Rc>>); +pub(super) struct ConnectionPool(Rc, Rc>); -impl ConnectionPool +impl ConnectionPool where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + T: Service + Unpin + 'static, T::Future: Unpin, @@ -73,35 +73,27 @@ where } } -impl Drop for ConnectionPool -where - Io: 'static, -{ +impl Drop for ConnectionPool { fn drop(&mut self) { self.1.borrow().waker.wake(); } } -impl Clone for ConnectionPool -where - Io: 'static, -{ +impl Clone for ConnectionPool { fn clone(&self) -> Self { ConnectionPool(self.0.clone(), self.1.clone()) } } -impl Service for ConnectionPool +impl Service for ConnectionPool where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + 'static, + T: Service + 'static, T::Future: Unpin, { type Request = Connect; - type Response = IoConnection; + type Response = Connection; type Error = ConnectError; - type Future = Pin, ConnectError>>>>; + type Future = Pin>>>; #[inline] fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { @@ -127,11 +119,12 @@ where }; // acquire connection - match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await { + let result = inner.borrow_mut().acquire(&key); + match result { // use existing connection Acquire::Acquired(io, created) => { trace!("Use existing connection for {:?}", req.uri); - Ok(IoConnection::new( + Ok(Connection::new( io, created, Some(Acquired(key, Some(inner))), @@ -165,31 +158,31 @@ where } } -enum Acquire { - Acquired(ConnectionType, Instant), +enum Acquire { + Acquired(ConnectionType, Instant), Available, NotAvailable, } -struct AvailableConnection { - io: ConnectionType, +struct AvailableConnection { + io: ConnectionType, used: Instant, created: Instant, } -pub(super) struct Inner { +pub(super) struct Inner { conn_lifetime: Duration, conn_keep_alive: Duration, disconnect_timeout: Millis, limit: usize, acquired: usize, - available: HashMap>>, - waiters: VecDeque<(Key, Connect, Waiter)>, + available: HashMap>, + waiters: VecDeque<(Key, Connect, Waiter)>, waker: LocalWaker, - pool: pool::Pool, ConnectError>>, + pool: pool::Pool>, } -impl Inner { +impl Inner { fn reserve(&mut self) { self.acquired += 1; } @@ -199,12 +192,9 @@ impl Inner { } } -impl Inner -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Inner { /// connection is not available, wait - fn wait_for(&mut self, connect: Connect) -> WaiterReceiver { + fn wait_for(&mut self, connect: Connect) -> WaiterReceiver { let (tx, rx) = self.pool.channel(); let key: Key = connect.uri.authority().unwrap().clone().into(); self.waiters.push_back((key, connect, tx)); @@ -226,7 +216,7 @@ where } } - fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire { + fn acquire(&mut self, key: &Key) -> Acquire { self.cleanup(); // check limits @@ -249,19 +239,22 @@ where CloseConnection::spawn(io, self.disconnect_timeout); } } else { - let mut io = conn.io; - let mut buf = [0; 2]; - let mut read_buf = ReadBuf::new(&mut buf); - if let ConnectionType::H1(ref mut s) = io { - match Pin::new(s).poll_read(cx, &mut read_buf) { - Poll::Pending => (), - Poll::Ready(Ok(_)) if !read_buf.filled().is_empty() => { - if let ConnectionType::H1(io) = io { - CloseConnection::spawn(io, self.disconnect_timeout); - } - continue; + let io = conn.io; + if let ConnectionType::H1(ref s) = io { + if s.is_closed() { + continue; + } + let is_valid = s.read().with_buf(|buf| { + if buf.is_empty() || (buf.len() == 2 && &buf[..] == b"\r\n") + { + buf.clear(); + true + } else { + false } - _ => continue, + }); + if !is_valid { + continue; } } return Acquire::Acquired(io, conn.created); @@ -271,7 +264,7 @@ where Acquire::Available } - fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) { + fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) { self.acquired -= 1; self.available .entry(key.clone()) @@ -284,7 +277,7 @@ where self.check_availibility(); } - fn release_close(&mut self, io: ConnectionType) { + fn release_close(&mut self, io: ConnectionType) { self.acquired -= 1; if let ConnectionType::H1(io) = io { CloseConnection::spawn(io, self.disconnect_timeout); @@ -300,19 +293,14 @@ where } } -struct ConnectionPoolSupport -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +struct ConnectionPoolSupport { connector: T, - inner: Rc>>, + inner: Rc>, } -impl Future for ConnectionPoolSupport +impl Future for ConnectionPoolSupport where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service - + Unpin, + T: Service + Unpin, T::Future: Unpin + 'static, { type Output = (); @@ -337,11 +325,11 @@ where }; let key = key.clone(); - match inner.acquire(&key, cx) { + match inner.acquire(&key) { Acquire::NotAvailable => break, Acquire::Acquired(io, created) => { let (key, _, tx) = inner.waiters.pop_front().unwrap(); - let _ = tx.send(Ok(IoConnection::new( + let _ = tx.send(Ok(Connection::new( io, created, Some(Acquired(key.clone(), Some(this.inner.clone()))), @@ -363,94 +351,43 @@ where } } -struct CloseConnection { - io: T, +struct CloseConnection { + io: IoBoxed, timeout: Option, shutdown: bool, } -impl CloseConnection -where - T: AsyncWrite + AsyncRead + Unpin + 'static, -{ - fn spawn(io: T, timeout: Millis) { - spawn(Self { - io, - shutdown: false, - timeout: timeout.map(sleep), +impl CloseConnection { + fn spawn(io: IoBoxed, timeout: Millis) { + spawn(async move { + io.shutdown().await; }); } } -impl Future for CloseConnection -where - T: AsyncWrite + AsyncRead + Unpin, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let mut this = self.as_mut(); - - // shutdown WRITE side - match Pin::new(&mut this.io).poll_shutdown(cx) { - Poll::Ready(Ok(())) => this.shutdown = true, - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(_)) => return Poll::Ready(()), - } - - // read until 0 or err - if let Some(ref timeout) = this.timeout { - match timeout.poll_elapsed(cx) { - Poll::Ready(_) => (), - Poll::Pending => { - let mut buf = [0u8; 512]; - let mut read_buf = ReadBuf::new(&mut buf); - loop { - match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(_)) => return Poll::Ready(()), - Poll::Ready(Ok(_)) => { - if read_buf.filled().is_empty() { - return Poll::Ready(()); - } - continue; - } - } - } - } - } - } - Poll::Ready(()) - } -} - -struct OpenConnection -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +struct OpenConnection { fut: F, h2: Option< Pin< Box< dyn Future< Output = Result< - (SendRequest, Connection), + (SendRequest, H2Connection), h2::Error, >, >, >, >, >, - tx: Option>, - guard: Option>, + tx: Option, + guard: Option, } -impl OpenConnection +impl OpenConnection where - F: Future> + Unpin + 'static, - Io: AsyncRead + AsyncWrite + Unpin + 'static, + F: Future> + Unpin + 'static, { - fn spawn(key: Key, tx: Waiter, inner: Rc>>, fut: F) { + fn spawn(key: Key, tx: Waiter, inner: Rc>, fut: F) { spawn(OpenConnection { fut, h2: None, @@ -463,10 +400,9 @@ where } } -impl Future for OpenConnection +impl Future for OpenConnection where - F: Future> + Unpin, - Io: AsyncRead + AsyncWrite + Unpin, + F: Future> + Unpin, { type Output = (); @@ -478,7 +414,7 @@ where return match Pin::new(h2).poll(cx) { Poll::Ready(Ok((snd, connection))) => { // h2 connection is ready - let conn = IoConnection::new( + let conn = Connection::new( ConnectionType::H2(snd), now(), Some(this.guard.take().unwrap().consume()), @@ -488,7 +424,7 @@ where conn.release() } spawn(async move { - let _ = connection.await; + // let _ = connection.await; }); Poll::Ready(()) } @@ -511,52 +447,43 @@ where } Poll::Ready(()) } - Poll::Ready(Ok((io, proto))) => { + Poll::Ready(Ok(io)) => { trace!("Connection is established"); // handle http1 proto - if proto == Protocol::Http1 { - let conn = IoConnection::new( - ConnectionType::H1(io), - now(), - Some(this.guard.take().unwrap().consume()), - ); - if let Err(Ok(conn)) = this.tx.take().unwrap().send(Ok(conn)) { - // waiter is gone, return connection to pool - conn.release() - } - Poll::Ready(()) - } else { - // init http2 handshake - this.h2 = Some(Box::pin(Builder::new().handshake(io))); - self.poll(cx) + //if proto == Protocol::Http1 { + let conn = Connection::new( + ConnectionType::H1(io), + now(), + Some(this.guard.take().unwrap().consume()), + ); + if let Err(Ok(conn)) = this.tx.take().unwrap().send(Ok(conn)) { + // waiter is gone, return connection to pool + conn.release() } + Poll::Ready(()) + // } else { + // init http2 handshake + // this.h2 = Some(Box::pin(Builder::new().handshake(io))); + // self.poll(cx) + //} } Poll::Pending => Poll::Pending, } } } -struct OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +struct OpenGuard { key: Key, - inner: Option>>>, + inner: Option>>, } -impl OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn consume(mut self) -> Acquired { +impl OpenGuard { + fn consume(mut self) -> Acquired { Acquired(self.key.clone(), self.inner.take()) } } -impl Drop for OpenGuard -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ +impl Drop for OpenGuard { fn drop(&mut self) { if let Some(i) = self.inner.take() { let mut inner = i.as_ref().borrow_mut(); @@ -566,20 +493,17 @@ where } } -pub(super) struct Acquired(Key, Option>>>); +pub(super) struct Acquired(Key, Option>>); -impl Acquired -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - pub(super) fn close(&mut self, conn: IoConnection) { +impl Acquired { + pub(super) fn close(&mut self, conn: Connection) { if let Some(inner) = self.1.take() { let (io, _) = conn.into_inner(); inner.as_ref().borrow_mut().release_close(io); } } - pub(super) fn release(&mut self, conn: IoConnection) { + pub(super) fn release(&mut self, conn: Connection) { if let Some(inner) = self.1.take() { let (io, created) = conn.into_inner(); inner @@ -590,7 +514,7 @@ where } } -impl Drop for Acquired { +impl Drop for Acquired { fn drop(&mut self) { if let Some(inner) = self.1.take() { inner.borrow_mut().release(); diff --git a/ntex/src/http/client/ws.rs b/ntex/src/http/client/ws.rs index 19f9deb9..a82a1e71 100644 --- a/ntex/src/http/client/ws.rs +++ b/ntex/src/http/client/ws.rs @@ -6,17 +6,16 @@ use coo_kie::{Cookie, CookieJar}; use nanorand::{Rng, WyRand}; use crate::codec::{AsyncRead, AsyncWrite, Framed}; -use crate::framed::{DispatchItem, Dispatcher, State}; use crate::http::error::HttpError; use crate::http::header::{self, HeaderName, HeaderValue, AUTHORIZATION}; use crate::http::{ConnectionType, Payload, RequestHead, StatusCode, Uri}; +use crate::io::{DefaultFilter, DispatchItem, Dispatcher, Filter, Io, IoBoxed}; use crate::service::{apply_fn, into_service, IntoService, Service}; use crate::util::Either; use crate::{channel::mpsc, rt, time::timeout, util::sink, util::Ready, ws}; pub use crate::ws::{CloseCode, CloseReason, Frame, Message}; -use super::connect::BoxedSocket; use super::error::{InvalidUrl, SendRequestError, WsClientError}; use super::response::ClientResponse; use super::ClientConfig; @@ -311,7 +310,7 @@ impl WsRequest { let fut = self.config.connector.open_tunnel(head.into(), self.addr); // set request timeout - let (head, framed) = if self.config.timeout.non_zero() { + let (head, io, codec) = if self.config.timeout.non_zero() { timeout(self.config.timeout, fut) .await .map_err(|_| SendRequestError::Timeout) @@ -377,13 +376,12 @@ impl WsRequest { // response and ws io Ok(WsConnection::new( ClientResponse::new(head, Payload::None), - framed.map_codec(|_| { - if server_mode { - ws::Codec::new().max_size(max_size) - } else { - ws::Codec::new().max_size(max_size).client_mode() - } - }), + io, + if server_mode { + ws::Codec::new().max_size(max_size) + } else { + ws::Codec::new().max_size(max_size).client_mode() + }, )) } } @@ -403,31 +401,20 @@ impl fmt::Debug for WsRequest { } } -pub struct WsConnection { - io: Io, - state: State, +pub struct WsConnection { + io: IoBoxed, codec: ws::Codec, res: ClientResponse, } -impl WsConnection -where - Io: AsyncRead + AsyncWrite + Unpin + 'static, -{ - fn new(res: ClientResponse, framed: Framed) -> Self { - let (io, codec, state) = State::from_framed(framed); - - Self { - io, - state, - codec, - res, - } +impl WsConnection { + fn new(res: ClientResponse, io: IoBoxed, codec: ws::Codec) -> Self { + Self { io, codec, res } } /// Get ws sink pub fn sink(&self) -> ws::WsSink { - ws::WsSink::new(self.state.clone(), self.codec.clone()) + ws::WsSink::new(self.io.get_ref(), self.codec.clone()) } /// Get reference to response @@ -458,10 +445,10 @@ where } /// Start client websockets service. - pub async fn start(self, service: F) -> Result<(), ws::WsError> + pub async fn start(self, service: U) -> Result<(), ws::WsError> where T: Service> + 'static, - F: IntoService, + U: IntoService, { let service = apply_fn( service.into_service().map_err(ws::WsError::Service), @@ -475,21 +462,22 @@ where DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => { Either::Right(Ready::Err(ws::WsError::Protocol(e))) } - DispatchItem::IoError(e) => { + DispatchItem::Disconnect(Some(e)) => { Either::Right(Ready::Err(ws::WsError::Io(e))) } + DispatchItem::Disconnect(None) => { + Either::Right(Ready::Err(ws::WsError::Disconnected)) + } }, ); - Dispatcher::new(self.io, self.codec, self.state, service, Default::default()) - .await + Dispatcher::new(self.io, self.codec, service, Default::default()).await } /// Consumes the `WsConnection`, returning it'as underlying I/O framed object /// and response. - pub fn into_inner(self) -> (ClientResponse, Framed) { - let framed = self.state.into_framed(self.io, self.codec); - (self.res, framed) + pub fn into_inner(self) -> (ClientResponse, IoBoxed, ws::Codec) { + (self.res, self.io, self.codec) } } diff --git a/ntex/src/http/config.rs b/ntex/src/http/config.rs index 6a5a1d03..e5153e64 100644 --- a/ntex/src/http/config.rs +++ b/ntex/src/http/config.rs @@ -1,7 +1,7 @@ -use std::{cell::Cell, cell::RefCell, ptr::copy_nonoverlapping, rc::Rc, time}; +use std::{cell::Cell, ptr::copy_nonoverlapping, rc::Rc, time}; -use crate::framed::Timer; use crate::http::{Request, Response}; +use crate::io::{IoRef, Timer}; use crate::service::boxed::BoxService; use crate::time::{sleep, Millis, Seconds, Sleep}; use crate::util::{BytesMut, PoolId}; @@ -94,9 +94,9 @@ impl ServiceConfig { } } -pub(super) type OnRequest = BoxService<(Request, Rc>), Request, Response>; +pub(super) type OnRequest = BoxService<(Request, IoRef), Request, Response>; -pub(super) struct DispatcherConfig { +pub(super) struct DispatcherConfig { pub(super) service: S, pub(super) expect: X, pub(super) upgrade: Option, @@ -107,16 +107,16 @@ pub(super) struct DispatcherConfig { pub(super) timer: DateService, pub(super) timer_h1: Timer, pub(super) pool: PoolId, - pub(super) on_request: Option>, + pub(super) on_request: Option, } -impl DispatcherConfig { +impl DispatcherConfig { pub(super) fn new( cfg: ServiceConfig, service: S, expect: X, upgrade: Option, - on_request: Option>, + on_request: Option, ) -> Self { DispatcherConfig { service, @@ -148,10 +148,6 @@ impl DispatcherConfig { self.keep_alive .map(|t| self.timer.now() + time::Duration::from(t)) } - - pub(super) fn now(&self) -> time::Instant { - self.timer.now() - } } const DATE_VALUE_LENGTH_HDR: usize = 39; diff --git a/ntex/src/http/h1/decoder.rs b/ntex/src/http/h1/decoder.rs index f61f34e5..bf21381a 100644 --- a/ntex/src/http/h1/decoder.rs +++ b/ntex/src/http/h1/decoder.rs @@ -322,7 +322,7 @@ impl MessageType for ResponseHead { Err(ParseError::TooLarge) } else { Ok(None) - } + }; } } }; diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index a797a26f..cb8f2b70 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,14 +1,10 @@ //! Framed transport dispatcher use std::task::{Context, Poll}; -use std::{ - cell::RefCell, error::Error, fmt, future::Future, marker, net, pin::Pin, rc::Rc, - time, -}; +use std::{error::Error, fmt, future::Future, marker, net, pin::Pin, rc::Rc, time}; -use crate::codec::{AsyncRead, AsyncWrite}; -use crate::framed::{ReadTask, State as IoState, WriteTask}; +use crate::io::{Filter, Io, IoRef}; use crate::service::Service; -use crate::util::Bytes; +use crate::{time::now, util::Bytes, util::Either}; use crate::http; use crate::http::body::{BodySize, MessageBody, ResponseBody}; @@ -37,19 +33,24 @@ bitflags::bitflags! { pin_project_lite::pin_project! { /// Dispatcher for HTTP/1.1 protocol - pub struct Dispatcher { + pub struct Dispatcher { #[pin] call: CallState, st: State, - inner: DispatcherInner, + inner: DispatcherInner, } } +#[derive(derive_more::Display)] enum State { Call, ReadRequest, ReadPayload, - SendPayload { body: ResponseBody }, + #[display(fmt = "State::SendPayload")] + SendPayload { + body: ResponseBody, + }, + #[display(fmt = "State::Upgrade")] Upgrade(Option), Stop, } @@ -65,17 +66,15 @@ pin_project_lite::pin_project! { } } -struct DispatcherInner { - io: Option>>, +struct DispatcherInner { + io: Option>, flags: Flags, codec: Codec, - config: Rc>, - state: IoState, + state: IoRef, + config: Rc>, expire: time::Instant, error: Option, payload: Option<(PayloadDecoder, PayloadSender)>, - peer_addr: Option, - on_connect_data: Option>, _t: marker::PhantomData<(S, B)>, } @@ -93,42 +92,33 @@ enum WritePayloadStatus { Continue, } -impl Dispatcher +impl Dispatcher where - T: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter + 'static, S: Service, S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, X::Error: ResponseError, - U: Service, + U: Service, Codec), Response = ()>, U::Error: Error + fmt::Display, { /// Construct new `Dispatcher` instance with outgoing messages stream. pub(in crate::http) fn new( - io: T, - config: Rc>, - peer_addr: Option, - on_connect_data: Option>, + io: Io, + config: Rc>, ) -> Self { + let mut expire = now(); + let state = io.get_ref(); let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled()); - let state = IoState::with_memory_pool(config.pool.into()); - state.set_disconnect_timeout(config.client_disconnect); - - let mut expire = config.timer_h1.now(); - let io = Rc::new(RefCell::new(io)); // slow-request timer if config.client_timeout.non_zero() { - expire += std::time::Duration::from(config.client_timeout); + expire += time::Duration::from(config.client_timeout); config.timer_h1.register(expire, expire, &state); } - // start support io tasks - crate::rt::spawn(ReadTask::new(io.clone(), state.clone())); - crate::rt::spawn(WriteTask::new(io.clone(), state.clone())); - Dispatcher { call: CallState::None, st: State::ReadRequest, @@ -138,27 +128,25 @@ where error: None, payload: None, codec, - config, state, + config, expire, - peer_addr, - on_connect_data, _t: marker::PhantomData, }, } } } -impl Future for Dispatcher +impl Future for Dispatcher where - T: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter, S: Service, S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, Codec), Response = ()>, U::Error: Error + fmt::Display + 'static, { type Output = Result<(), DispatchError>; @@ -204,7 +192,6 @@ where ) }); if this.inner.flags.contains(Flags::UPGRADE) { - this.inner.state.stop_io(cx.waker()); *this.st = State::Upgrade(Some(req)); return Poll::Pending; } else { @@ -292,7 +279,6 @@ where log::trace!("keep-alive timeout, close connection"); } *this.st = State::Stop; - continue; } @@ -307,7 +293,7 @@ where req, pl ); - req.head_mut().peer_addr = this.inner.peer_addr; + req.head_mut().io = Some(this.inner.state.clone()); // configure request payload let upgrade = match pl { @@ -340,16 +326,9 @@ where ); } - // set on_connect data - if let Some(ref on_connect) = this.inner.on_connect_data - { - on_connect.set(&mut req.extensions_mut()); - } - if upgrade { // Handle UPGRADE request log::trace!("prep io for upgrade handler"); - this.inner.state.stop_io(cx.waker()); *this.st = State::Upgrade(Some(req)); return Poll::Pending; } else { @@ -361,11 +340,7 @@ where CallState::Filter { fut: f.call(( req, - this.inner - .io - .as_ref() - .unwrap() - .clone(), + this.inner.state.clone(), )), } } else if req.head().expect() { @@ -390,13 +365,13 @@ where if this.inner.flags.contains(Flags::STARTED) && (!this.inner.flags.contains(Flags::KEEPALIVE) || !this.inner.codec.keepalive_enabled() - || this.inner.state.is_io_err()) + || !this.inner.state.is_io_open()) { *this.st = State::Stop; - this.inner.state.dispatcher_stopped(); + this.inner.state.stop_dispatcher(); continue; } - this.inner.state.read().wake(cx.waker()); + let _ = read.poll_ready(cx); return Poll::Pending; } Err(err) => { @@ -418,13 +393,13 @@ where *this.st = State::Stop; continue; } - this.inner.state.register_dispatcher(cx.waker()); + let _ = read.poll_ready(cx); return Poll::Pending; } } // consume request's payload State::ReadPayload => { - if this.inner.state.is_io_err() { + if !this.inner.state.is_io_open() { *this.st = State::Stop; } else { loop { @@ -445,7 +420,7 @@ where } // send response body State::SendPayload { ref mut body } => { - if this.inner.state.is_io_err() { + if !this.inner.state.is_io_open() { *this.st = State::Stop; } else { this.inner.poll_read_payload(cx); @@ -459,7 +434,7 @@ where this.inner .state .write() - .enable_backpressure(Some(cx.waker())); + .enable_backpressure(Some(cx)); return Poll::Pending; } WritePayloadStatus::Continue => (), @@ -470,50 +445,48 @@ where } // stop io tasks and call upgrade service State::Upgrade(ref mut req) => { - // check if all io tasks have been stopped - let io = if Rc::strong_count(this.inner.io.as_ref().unwrap()) == 1 { - if let Ok(io) = Rc::try_unwrap(this.inner.io.take().unwrap()) { - io.into_inner() - } else { - return Poll::Ready(Err(DispatchError::InternalError)); - } - } else { - // wait next task stop - this.inner.state.register_dispatcher(cx.waker()); - return Poll::Pending; - }; log::trace!("initate upgrade handling"); + let io = this.inner.io.take().unwrap(); let req = req.take().unwrap(); *this.st = State::Call; - this.inner.state.reset_io_stop(); // Handle UPGRADE request this.call.set(CallState::Upgrade { fut: this.inner.config.upgrade.as_ref().unwrap().call(( req, io, - this.inner.state.clone(), this.inner.codec.clone(), )), }); } // prepare to shutdown State::Stop => { - this.inner.state.shutdown_io(); this.inner.unregister_keepalive(); + if this + .inner + .io + .as_ref() + .unwrap() + .poll_shutdown(cx)? + .is_ready() + { + // get io error + if this.inner.error.is_none() { + this.inner.error = + this.inner.state.take_error().map(DispatchError::Io); + } - // get io error - if this.inner.error.is_none() { - this.inner.error = - this.inner.state.take_io_error().map(DispatchError::Io); - } - - return Poll::Ready(if let Some(err) = this.inner.error.take() { - Err(err) + return Poll::Ready( + if let Some(err) = this.inner.error.take() { + Err(err) + } else { + Ok(()) + }, + ); } else { - Ok(()) - }); + return Poll::Pending; + } } } } @@ -536,8 +509,7 @@ where fn reset_keepalive(&mut self) { // re-register keep-alive if self.flags.contains(Flags::KEEPALIVE) && self.config.keep_alive.non_zero() { - let expire = self.config.timer_h1.now() - + std::time::Duration::from(self.config.keep_alive); + let expire = now() + time::Duration::from(self.config.keep_alive); if expire != self.expire { self.config .timer_h1 @@ -571,11 +543,11 @@ where } fn send_response(&mut self, msg: Response<()>, body: ResponseBody) -> State { - trace!("Sending response: {:?} body: {:?}", msg, body.size()); + trace!("sending response: {:?} body: {:?}", msg, body.size()); // we dont need to process responses if socket is disconnected // but we still want to handle requests with app service // so we skip response processing for droppped connection - if !self.state.is_io_err() { + if self.state.is_io_open() { let result = self .state .write() @@ -617,7 +589,7 @@ where ) -> WritePayloadStatus { match item { Some(Ok(item)) => { - trace!("Got response chunk: {:?}", item.len()); + trace!("got response chunk: {:?}", item.len()); match self .state .write() @@ -637,7 +609,7 @@ where } } None => { - trace!("Response payload eof"); + trace!("response payload eof"); if let Err(err) = self.state.write().encode(Message::Chunk(None), &self.codec) { @@ -653,7 +625,7 @@ where } } Some(Err(e)) => { - trace!("Error during response body poll: {:?}", e); + trace!("error during response body poll: {:?}", e); self.error = Some(DispatchError::ResponsePayload(e)); WritePayloadStatus::Next(State::Stop) } @@ -686,13 +658,13 @@ where break; } Ok(None) => { - if self.state.is_io_err() { + if !self.state.is_io_open() { payload.1.set_error(PayloadError::EncodingCorrupted); self.payload = None; self.error = Some(ParseError::Incomplete.into()); return ReadPayloadStatus::Dropped; } else { - read.wake(cx.waker()); + let _ = read.poll_ready(cx); break; } } @@ -737,6 +709,7 @@ mod tests { use crate::http::config::{DispatcherConfig, ServiceConfig}; use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler}; use crate::http::{body, Request, ResponseHead, StatusCode}; + use crate::io::{self as nio, DefaultFilter}; use crate::service::{boxed, fn_service, IntoService}; use crate::util::{lazy, next, Bytes, BytesMut}; use crate::{codec::Decoder, testing::Io, time::sleep, time::Millis}; @@ -747,7 +720,7 @@ mod tests { pub(crate) fn h1( stream: Io, service: F, - ) -> Dispatcher> + ) -> Dispatcher> where F: IntoService, S: Service, @@ -756,7 +729,7 @@ mod tests { B: MessageBody, { Dispatcher::new( - stream, + nio::Io::new(stream), Rc::new(DispatcherConfig::new( ServiceConfig::default(), service.into_service(), @@ -764,8 +737,6 @@ mod tests { None, None, )), - None, - None, ) } @@ -777,20 +748,22 @@ mod tests { S::Response: Into>, B: MessageBody + 'static, { - crate::rt::spawn( - Dispatcher::>::new( - stream, - Rc::new(DispatcherConfig::new( - ServiceConfig::default(), - service.into_service(), - ExpectHandler, - None, - None, - )), + crate::rt::spawn(Dispatcher::< + DefaultFilter, + S, + B, + ExpectHandler, + UpgradeHandler, + >::new( + nio::Io::new(stream), + Rc::new(DispatcherConfig::new( + ServiceConfig::default(), + service.into_service(), + ExpectHandler, None, None, - ), - ); + )), + )); } fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead { @@ -806,7 +779,7 @@ mod tests { let data = Rc::new(Cell::new(false)); let data2 = data.clone(); let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( - server, + nio::Io::new(server), Rc::new(DispatcherConfig::new( ServiceConfig::default(), fn_service(|_| { @@ -821,8 +794,6 @@ mod tests { }, ))), )), - None, - None, ); sleep(Millis(50)).await; diff --git a/ntex/src/http/h1/service.rs b/ntex/src/http/h1/service.rs index 30f0a9da..d6888781 100644 --- a/ntex/src/http/h1/service.rs +++ b/ntex/src/http/h1/service.rs @@ -3,35 +3,33 @@ use std::{ task, }; -use crate::codec::{AsyncRead, AsyncWrite}; -use crate::framed::State as IoState; use crate::http::body::MessageBody; use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig}; use crate::http::error::{DispatchError, ResponseError}; use crate::http::helpers::DataFactory; use crate::http::request::Request; use crate::http::response::Response; +use crate::io::{DefaultFilter, Filter, Io, IoRef}; use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; -use crate::{rt::net::TcpStream, time::Millis, util::Pool}; +use crate::{time::Millis, util::Pool}; use super::codec::Codec; use super::dispatcher::Dispatcher; use super::{ExpectHandler, UpgradeHandler}; /// `ServiceFactory` implementation for HTTP1 transport -pub struct H1Service> { +pub struct H1Service> { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, - on_connect: Option Box>>, - on_request: RefCell>>, + on_request: RefCell>, #[allow(dead_code)] handshake_timeout: Millis, - _t: marker::PhantomData<(T, B)>, + _t: marker::PhantomData<(F, B)>, } -impl H1Service +impl H1Service where S: ServiceFactory, S::Error: ResponseError + 'static, @@ -40,15 +38,14 @@ where B: MessageBody, { /// Create new `HttpService` instance with config. - pub(crate) fn with_config>( + pub(crate) fn with_config>( cfg: ServiceConfig, - service: F, + service: U, ) -> Self { H1Service { srv: service.into_factory(), expect: ExpectHandler, upgrade: None, - on_connect: None, on_request: RefCell::new(None), handshake_timeout: cfg.0.ssl_handshake_timeout, _t: marker::PhantomData, @@ -57,53 +54,14 @@ where } } -impl H1Service -where - S: ServiceFactory, - S::Error: ResponseError + 'static, - S::InitError: fmt::Debug, - S::Response: Into>, - S::Future: 'static, - B: MessageBody, - X: ServiceFactory, - X::Error: ResponseError + 'static, - X::InitError: fmt::Debug, - X::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, TcpStream, IoState, Codec), - Response = (), - >, - U::Error: fmt::Display + Error + 'static, - U::InitError: fmt::Debug, - U::Future: 'static, -{ - /// Create simple tcp stream service - pub fn tcp( - self, - ) -> impl ServiceFactory< - Config = (), - Request = TcpStream, - Response = (), - Error = DispatchError, - InitError = (), - > { - pipeline_factory(|io: TcpStream| async move { - let peer_addr = io.peer_addr().ok(); - Ok((io, peer_addr)) - }) - .and_then(self) - } -} - #[cfg(feature = "openssl")] mod openssl { use super::*; - use crate::server::openssl::{Acceptor, SslAcceptor, SslStream}; + use crate::server::openssl::{Acceptor, SslAcceptor, SslFilter}; use crate::server::SslError; - impl H1Service, S, B, X, U> + impl H1Service, S, B, X, U> where S: ServiceFactory, S::Error: ResponseError + 'static, @@ -117,7 +75,7 @@ mod openssl { X::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, SslStream, IoState, Codec), + Request = (Request, Io>, Codec), Response = (), >, U::Error: fmt::Display + Error + 'static, @@ -130,7 +88,7 @@ mod openssl { acceptor: SslAcceptor, ) -> impl ServiceFactory< Config = (), - Request = TcpStream, + Request = Io, Response = (), Error = SslError, InitError = (), @@ -141,71 +99,68 @@ mod openssl { .map_err(SslError::Ssl) .map_init_err(|_| panic!()), ) - .and_then(|io: SslStream| async move { - let peer_addr = io.get_ref().peer_addr().ok(); - Ok((io, peer_addr)) - }) .and_then(self.map_err(SslError::Service)) } } } -#[cfg(feature = "rustls")] -mod rustls { - use super::*; - use crate::server::rustls::{Acceptor, ServerConfig, TlsStream}; - use crate::server::SslError; - use std::fmt; +// #[cfg(feature = "rustls")] +// mod rustls { +// use super::*; +// use crate::server::rustls::{Acceptor, ServerConfig, TlsStream}; +// use crate::server::SslError; +// use std::fmt; - impl H1Service, S, B, X, U> - where - S: ServiceFactory, - S::Error: ResponseError + 'static, - S::InitError: fmt::Debug, - S::Response: Into>, - S::Future: 'static, - B: MessageBody, - X: ServiceFactory, - X::Error: ResponseError + 'static, - X::InitError: fmt::Debug, - X::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, TlsStream, IoState, Codec), - Response = (), - >, - U::Error: fmt::Display + Error + 'static, - U::InitError: fmt::Debug, - U::Future: 'static, - { - /// Create rustls based service - pub fn rustls( - self, - config: ServerConfig, - ) -> impl ServiceFactory< - Config = (), - Request = TcpStream, - Response = (), - Error = SslError, - InitError = (), - > { - pipeline_factory( - Acceptor::new(config) - .timeout(self.handshake_timeout) - .map_err(SslError::Ssl) - .map_init_err(|_| panic!()), - ) - .and_then(|io: TlsStream| async move { - let peer_addr = io.get_ref().0.peer_addr().ok(); - Ok((io, peer_addr)) - }) - .and_then(self.map_err(SslError::Service)) - } - } -} +// impl H1Service, S, B, X, U> +// where +// S: ServiceFactory, +// S::Error: ResponseError + 'static, +// S::InitError: fmt::Debug, +// S::Response: Into>, +// S::Future: 'static, +// B: MessageBody, +// X: ServiceFactory, +// X::Error: ResponseError + 'static, +// X::InitError: fmt::Debug, +// X::Future: 'static, +// U: ServiceFactory< +// Config = (), +// Request = (Request, TlsStream, IoState, Codec), +// Response = (), +// >, +// U::Error: fmt::Display + Error + 'static, +// U::InitError: fmt::Debug, +// U::Future: 'static, +// { +// /// Create rustls based service +// pub fn rustls( +// self, +// config: ServerConfig, +// ) -> impl ServiceFactory< +// Config = (), +// Request = TcpStream, +// Response = (), +// Error = SslError, +// InitError = (), +// > { +// pipeline_factory( +// Acceptor::new(config) +// .timeout(self.handshake_timeout) +// .map_err(SslError::Ssl) +// .map_init_err(|_| panic!()), +// ) +// .and_then(|io: TlsStream| async move { +// let peer_addr = io.get_ref().0.peer_addr().ok(); +// Ok((io, peer_addr)) +// }) +// .and_then(self.map_err(SslError::Service)) +// } +// } +// } -impl H1Service +impl H1Service where + F: Filter, S: ServiceFactory, S::Error: ResponseError + 'static, S::Response: Into>, @@ -213,7 +168,7 @@ where S::Future: 'static, B: MessageBody, { - pub fn expect(self, expect: X1) -> H1Service + pub fn expect(self, expect: X1) -> H1Service where X1: ServiceFactory, X1::Error: ResponseError + 'static, @@ -225,16 +180,15 @@ where cfg: self.cfg, srv: self.srv, upgrade: self.upgrade, - on_connect: self.on_connect, on_request: self.on_request, handshake_timeout: self.handshake_timeout, _t: marker::PhantomData, } } - pub fn upgrade(self, upgrade: Option) -> H1Service + pub fn upgrade(self, upgrade: Option) -> H1Service where - U1: ServiceFactory, + U1: ServiceFactory, Codec), Response = ()>, U1::Error: fmt::Display + Error + 'static, U1::InitError: fmt::Debug, U1::Future: 'static, @@ -244,34 +198,24 @@ where cfg: self.cfg, srv: self.srv, expect: self.expect, - on_connect: self.on_connect, on_request: self.on_request, handshake_timeout: self.handshake_timeout, _t: marker::PhantomData, } } - /// Set on connect callback. - pub(crate) fn on_connect( - mut self, - f: Option Box>>, - ) -> Self { - self.on_connect = f; - self - } - /// Set req request callback. /// /// It get called once per request. - pub(crate) fn on_request(self, f: Option>) -> Self { + pub(crate) fn on_request(self, f: Option) -> Self { *self.on_request.borrow_mut() = f; self } } -impl ServiceFactory for H1Service +impl ServiceFactory for H1Service where - T: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter + 'static, S: ServiceFactory, S::Error: ResponseError + 'static, S::Response: Into>, @@ -282,28 +226,23 @@ where X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, T, IoState, Codec), - Response = (), - >, + U: ServiceFactory, Codec), Response = ()>, U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, { type Config = (); - type Request = (T, Option); + type Request = Io; type Response = (); type Error = DispatchError; type InitError = (); - type Service = H1ServiceHandler; + type Service = H1ServiceHandler; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { let fut = self.srv.new_service(()); let fut_ex = self.expect.new_service(()); let fut_upg = self.upgrade.as_ref().map(|f| f.new_service(())); - let on_connect = self.on_connect.clone(); let on_request = self.on_request.borrow_mut().take(); let cfg = self.cfg.clone(); @@ -331,7 +270,6 @@ where Ok(H1ServiceHandler { pool, config, - on_connect, _t: marker::PhantomData, }) }) @@ -339,29 +277,28 @@ where } /// `Service` implementation for HTTP1 transport -pub struct H1ServiceHandler { +pub struct H1ServiceHandler { pool: Pool, - config: Rc>, - on_connect: Option Box>>, - _t: marker::PhantomData<(T, B)>, + config: Rc>, + _t: marker::PhantomData<(F, B)>, } -impl Service for H1ServiceHandler +impl Service for H1ServiceHandler where - T: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter + 'static, S: Service, S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, Codec), Response = ()>, U::Error: fmt::Display + Error + 'static, { - type Request = (T, Option); + type Request = Io; type Response = (); type Error = DispatchError; - type Future = Dispatcher; + type Future = Dispatcher; fn poll_ready( &self, @@ -429,9 +366,7 @@ where } } - fn call(&self, (io, addr): Self::Request) -> Self::Future { - let on_connect = self.on_connect.as_ref().map(|f| f(&io)); - - Dispatcher::new(io, self.config.clone(), addr, on_connect) + fn call(&self, io: Self::Request) -> Self::Future { + Dispatcher::new(io, self.config.clone()) } } diff --git a/ntex/src/http/h1/upgrade.rs b/ntex/src/http/h1/upgrade.rs index 41080e13..5a0285d9 100644 --- a/ntex/src/http/h1/upgrade.rs +++ b/ntex/src/http/h1/upgrade.rs @@ -2,16 +2,17 @@ use std::{io, marker::PhantomData, task::Context, task::Poll}; use crate::http::h1::Codec; use crate::http::request::Request; -use crate::{framed::State, util::Ready, Service, ServiceFactory}; +use crate::io::Io; +use crate::{util::Ready, Service, ServiceFactory}; -pub struct UpgradeHandler(PhantomData); +pub struct UpgradeHandler(PhantomData); -impl ServiceFactory for UpgradeHandler { +impl ServiceFactory for UpgradeHandler { type Config = (); - type Request = (Request, T, State, Codec); + type Request = (Request, Io, Codec); type Response = (); type Error = io::Error; - type Service = UpgradeHandler; + type Service = UpgradeHandler; type InitError = io::Error; type Future = Ready; @@ -21,8 +22,8 @@ impl ServiceFactory for UpgradeHandler { } } -impl Service for UpgradeHandler { - type Request = (Request, T, State, Codec); +impl Service for UpgradeHandler { + type Request = (Request, Io, Codec); type Response = (); type Error = io::Error; type Future = Ready; diff --git a/ntex/src/http/h2/mod.rs b/ntex/src/http/h2/mod.rs index 6cfb86ad..2d04fa49 100644 --- a/ntex/src/http/h2/mod.rs +++ b/ntex/src/http/h2/mod.rs @@ -4,11 +4,11 @@ use std::task::{Context, Poll}; use h2::RecvStream; -mod dispatcher; -mod service; +//mod dispatcher; +//mod service; -pub use self::dispatcher::Dispatcher; -pub use self::service::H2Service; +//pub use self::dispatcher::Dispatcher; +//pub use self::service::H2Service; use crate::{http::error::PayloadError, util::Bytes, Stream}; /// H2 receive stream diff --git a/ntex/src/http/message.rs b/ntex/src/http/message.rs index a441a8ee..030ed9c0 100644 --- a/ntex/src/http/message.rs +++ b/ntex/src/http/message.rs @@ -6,6 +6,7 @@ use bitflags::bitflags; use crate::http::header::HeaderMap; use crate::http::{header, Method, StatusCode, Uri, Version}; +use crate::io::IoRef; use crate::util::Extensions; /// Represents various types of connection @@ -45,19 +46,19 @@ pub struct RequestHead { pub version: Version, pub headers: HeaderMap, pub extensions: RefCell, - pub peer_addr: Option, + pub io: Option, pub(super) flags: Flags, } impl Default for RequestHead { fn default() -> RequestHead { RequestHead { + io: None, uri: Uri::default(), method: Method::default(), version: Version::HTTP_11, headers: HeaderMap::with_capacity(16), flags: Flags::empty(), - peer_addr: None, extensions: RefCell::new(Extensions::new()), } } @@ -65,6 +66,7 @@ impl Default for RequestHead { impl Head for RequestHead { fn clear(&mut self) { + self.io = None; self.flags = Flags::empty(); self.headers.clear(); self.extensions.get_mut().clear(); diff --git a/ntex/src/http/request.rs b/ntex/src/http/request.rs index 1cbf35ed..ac4112d9 100644 --- a/ntex/src/http/request.rs +++ b/ntex/src/http/request.rs @@ -6,6 +6,7 @@ use crate::http::header::HeaderMap; use crate::http::httpmessage::HttpMessage; use crate::http::message::{Message, RequestHead}; use crate::http::payload::Payload; +use crate::io::IoRef; use crate::util::Extensions; /// Request @@ -126,13 +127,21 @@ impl Request { self.head().method == Method::CONNECT } + /// Io reference for current connection + #[inline] + pub fn io(&self) -> Option<&IoRef> { + self.head().io.as_ref() + } + /// Peer socket address /// /// Peer address is actual socket address, if proxy is used in front of /// ntex http server, then peer address would be address of this proxy. #[inline] pub fn peer_addr(&self) -> Option { - self.head().peer_addr + // TODO! fix + // self.head().peer_addr + None } /// Get request's payload diff --git a/ntex/src/http/service.rs b/ntex/src/http/service.rs index 1f87abd8..f2e6315e 100644 --- a/ntex/src/http/service.rs +++ b/ntex/src/http/service.rs @@ -6,7 +6,7 @@ use std::{ use h2::server::{self, Handshake}; use crate::codec::{AsyncRead, AsyncWrite}; -use crate::framed::State; +use crate::io::{DefaultFilter, Filter, Io, IoRef}; use crate::rt::net::TcpStream; use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; use crate::time::{Millis, Seconds}; @@ -19,20 +19,20 @@ use super::error::{DispatchError, ResponseError}; use super::helpers::DataFactory; use super::request::Request; use super::response::Response; -use super::{h1, h2::Dispatcher, Protocol}; +//use super::{h1, h2::Dispatcher, Protocol}; +use super::{h1, Protocol}; /// `ServiceFactory` HTTP1.1/HTTP2 transport implementation -pub struct HttpService> { +pub struct HttpService> { srv: S, cfg: ServiceConfig, expect: X, upgrade: Option, - on_connect: Option Box>>, - on_request: cell::RefCell>>, - _t: marker::PhantomData<(T, B)>, + on_request: cell::RefCell>, + _t: marker::PhantomData<(F, B)>, } -impl HttpService +impl HttpService where S: ServiceFactory, S::Error: ResponseError + 'static, @@ -43,13 +43,14 @@ where B: MessageBody + 'static, { /// Create builder for `HttpService` instance. - pub fn build() -> HttpServiceBuilder { + pub fn build() -> HttpServiceBuilder { HttpServiceBuilder::new() } } -impl HttpService +impl HttpService where + F: Filter, S: ServiceFactory, S::Error: ResponseError + 'static, S::InitError: fmt::Debug, @@ -59,7 +60,7 @@ where B: MessageBody + 'static, { /// Create new `HttpService` instance. - pub fn new>(service: F) -> Self { + pub fn new>(service: U) -> Self { let cfg = ServiceConfig::new( KeepAlive::Timeout(Seconds(5)), Millis(5_000), @@ -73,31 +74,30 @@ where srv: service.into_factory(), expect: h1::ExpectHandler, upgrade: None, - on_connect: None, on_request: cell::RefCell::new(None), _t: marker::PhantomData, } } /// Create new `HttpService` instance with config. - pub(crate) fn with_config>( + pub(crate) fn with_config>( cfg: ServiceConfig, - service: F, + service: U, ) -> Self { HttpService { cfg, srv: service.into_factory(), expect: h1::ExpectHandler, upgrade: None, - on_connect: None, on_request: cell::RefCell::new(None), _t: marker::PhantomData, } } } -impl HttpService +impl HttpService where + F: Filter, S: ServiceFactory, S::Error: ResponseError + 'static, S::InitError: fmt::Debug, @@ -111,7 +111,7 @@ where /// Service get called with request that contains `EXPECT` header. /// Service must return request in case of success, in that case /// request will be forwarded to main service. - pub fn expect(self, expect: X1) -> HttpService + pub fn expect(self, expect: X1) -> HttpService where X1: ServiceFactory, X1::Error: ResponseError, @@ -123,7 +123,6 @@ where cfg: self.cfg, srv: self.srv, upgrade: self.upgrade, - on_connect: self.on_connect, on_request: self.on_request, _t: marker::PhantomData, } @@ -133,11 +132,11 @@ where /// /// If service is provided then normal requests handling get halted /// and this service get called with original request and framed object. - pub fn upgrade(self, upgrade: Option) -> HttpService + pub fn upgrade(self, upgrade: Option) -> HttpService where U1: ServiceFactory< Config = (), - Request = (Request, T, State, h1::Codec), + Request = (Request, Io, h1::Codec), Response = (), >, U1::Error: fmt::Display + error::Error + 'static, @@ -149,77 +148,25 @@ where cfg: self.cfg, srv: self.srv, expect: self.expect, - on_connect: self.on_connect, on_request: self.on_request, _t: marker::PhantomData, } } - /// Set on connect callback. - pub(crate) fn on_connect( - mut self, - f: Option Box>>, - ) -> Self { - self.on_connect = f; - self - } - /// Set on request callback. - pub(crate) fn on_request(self, f: Option>) -> Self { + pub(crate) fn on_request(self, f: Option) -> Self { *self.on_request.borrow_mut() = f; self } } -impl HttpService -where - S: ServiceFactory, - S::Error: ResponseError + 'static, - S::InitError: fmt::Debug, - S::Response: Into> + 'static, - S::Future: 'static, - ::Future: 'static, - B: MessageBody + 'static, - X: ServiceFactory, - X::Error: ResponseError + 'static, - X::InitError: fmt::Debug, - X::Future: 'static, - ::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, TcpStream, State, h1::Codec), - Response = (), - >, - U::Error: fmt::Display + error::Error + 'static, - U::InitError: fmt::Debug, - U::Future: 'static, - ::Future: 'static, -{ - /// Create simple tcp stream service - pub fn tcp( - self, - ) -> impl ServiceFactory< - Config = (), - Request = TcpStream, - Response = (), - Error = DispatchError, - InitError = (), - > { - pipeline_factory(|io: TcpStream| async move { - let peer_addr = io.peer_addr().ok(); - Ok((io, Protocol::Http1, peer_addr)) - }) - .and_then(self) - } -} - #[cfg(feature = "openssl")] mod openssl { use super::*; - use crate::server::openssl::{Acceptor, SslAcceptor, SslStream}; + use crate::server::openssl::{Acceptor, SslAcceptor, SslFilter}; use crate::server::SslError; - impl HttpService, S, B, X, U> + impl HttpService, S, B, X, U> where S: ServiceFactory, S::Error: ResponseError + 'static, @@ -235,7 +182,7 @@ mod openssl { ::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, SslStream, State, h1::Codec), + Request = (Request, Io>, h1::Codec), Response = (), >, U::Error: fmt::Display + error::Error + 'static, @@ -249,7 +196,7 @@ mod openssl { acceptor: SslAcceptor, ) -> impl ServiceFactory< Config = (), - Request = TcpStream, + Request = Io, Response = (), Error = SslError, InitError = (), @@ -260,19 +207,6 @@ mod openssl { .map_err(SslError::Ssl) .map_init_err(|_| panic!()), ) - .and_then(|io: SslStream| async move { - let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() { - if protos.windows(2).any(|window| window == b"h2") { - Protocol::Http2 - } else { - Protocol::Http1 - } - } else { - Protocol::Http1 - }; - let peer_addr = io.get_ref().peer_addr().ok(); - Ok((io, proto, peer_addr)) - }) .and_then(self.map_err(SslError::Service)) } } @@ -284,8 +218,9 @@ mod rustls { use crate::server::rustls::{Acceptor, ServerConfig, TlsStream}; use crate::server::SslError; - impl HttpService, S, B, X, U> + impl HttpService where + F: Filter, S: ServiceFactory, S::Error: ResponseError + 'static, S::InitError: fmt::Debug, @@ -300,7 +235,7 @@ mod rustls { ::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, TlsStream, State, h1::Codec), + Request = (Request, Io, h1::Codec), Response = (), >, U::Error: fmt::Display + error::Error + 'static, @@ -311,47 +246,49 @@ mod rustls { /// Create openssl based service pub fn rustls( self, - mut config: ServerConfig, + config: ServerConfig, ) -> impl ServiceFactory< Config = (), - Request = TcpStream, + Request = Io, Response = (), - Error = SslError, + //Error = SslError, + Error = DispatchError, InitError = (), > { - let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; - config.alpn_protocols = protos; + self + // let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; + // config.alpn_protocols = protos; - pipeline_factory( - Acceptor::new(config) - .timeout(self.cfg.0.ssl_handshake_timeout) - .map_err(SslError::Ssl) - .map_init_err(|_| panic!()), - ) - .and_then(|io: TlsStream| async move { - let proto = io - .get_ref() - .1 - .alpn_protocol() - .and_then(|protos| { - if protos.windows(2).any(|window| window == b"h2") { - Some(Protocol::Http2) - } else { - None - } - }) - .unwrap_or(Protocol::Http1); - let peer_addr = io.get_ref().0.peer_addr().ok(); - Ok((io, proto, peer_addr)) - }) - .and_then(self.map_err(SslError::Service)) + // pipeline_factory( + // Acceptor::new(config) + // .timeout(self.cfg.0.ssl_handshake_timeout) + // .map_err(SslError::Ssl) + // .map_init_err(|_| panic!()), + // ) + // .and_then(|io: TlsStream| async move { + // let proto = io + // .get_ref() + // .1 + // .alpn_protocol() + // .and_then(|protos| { + // if protos.windows(2).any(|window| window == b"h2") { + // Some(Protocol::Http2) + // } else { + // None + // } + // }) + // .unwrap_or(Protocol::Http1); + // let peer_addr = io.get_ref().0.peer_addr().ok(); + // Ok((io, proto, peer_addr)) + // }) + // .and_then(self.map_err(SslError::Service)) } } } -impl ServiceFactory for HttpService +impl ServiceFactory for HttpService where - T: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter + 'static, S: ServiceFactory, S::Error: ResponseError + 'static, S::InitError: fmt::Debug, @@ -364,29 +301,24 @@ where X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, T, State, h1::Codec), - Response = (), - >, + U: ServiceFactory, h1::Codec), Response = ()>, U::Error: fmt::Display + error::Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, ::Future: 'static, { type Config = (); - type Request = (T, Protocol, Option); + type Request = Io; type Response = (); type Error = DispatchError; type InitError = (); - type Service = HttpServiceHandler; + type Service = HttpServiceHandler; type Future = Pin>>>; fn new_service(&self, _: ()) -> Self::Future { let fut = self.srv.new_service(()); let fut_ex = self.expect.new_service(()); let fut_upg = self.upgrade.as_ref().map(|f| f.new_service(())); - let on_connect = self.on_connect.clone(); let on_request = self.on_request.borrow_mut().take(); let cfg = self.cfg.clone(); @@ -414,7 +346,6 @@ where Ok(HttpServiceHandler { pool, - on_connect, config: Rc::new(config), _t: marker::PhantomData, }) @@ -423,16 +354,15 @@ where } /// `Service` implementation for http transport -pub struct HttpServiceHandler { +pub struct HttpServiceHandler { pool: Pool, - config: Rc>, - on_connect: Option Box>>, - _t: marker::PhantomData<(T, B, X)>, + config: Rc>, + _t: marker::PhantomData<(F, B, X)>, } -impl Service for HttpServiceHandler +impl Service for HttpServiceHandler where - T: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter + 'static, S: Service, S::Error: ResponseError + 'static, S::Future: 'static, @@ -440,13 +370,13 @@ where B: MessageBody + 'static, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, h1::Codec), Response = ()>, U::Error: fmt::Display + error::Error + 'static, { - type Request = (T, Protocol, Option); + type Request = Io; type Response = (); type Error = DispatchError; - type Future = HttpServiceHandlerResponse; + type Future = HttpServiceHandlerResponse; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let cfg = self.config.as_ref(); @@ -507,46 +437,36 @@ where } } - fn call(&self, (io, proto, peer_addr): Self::Request) -> Self::Future { - log::trace!( - "New http connection protocol {:?} peer address {:?}", - proto, - peer_addr - ); - let on_connect = self.on_connect.as_ref().map(|f| f(&io)); + fn call(&self, io: Self::Request) -> Self::Future { + // log::trace!("New http connection protocol {:?}", proto); - match proto { - Protocol::Http2 => HttpServiceHandlerResponse { - state: ResponseState::H2Handshake { - data: Some(( - server::Builder::new().handshake(io), - self.config.clone(), - on_connect, - peer_addr, - )), - }, - }, - Protocol::Http1 => HttpServiceHandlerResponse { - state: ResponseState::H1 { - fut: h1::Dispatcher::new( - io, - self.config.clone(), - peer_addr, - on_connect, - ), - }, + //match proto { + //Protocol::Http2 => todo!(), + // HttpServiceHandlerResponse { + // state: ResponseState::H2Handshake { + // data: Some(( + // server::Builder::new().handshake(io), + // self.config.clone(), + // on_connect, + // peer_addr, + // )), + // }, + // }, + // Protocol::Http1 => + HttpServiceHandlerResponse { + state: ResponseState::H1 { + fut: h1::Dispatcher::new(io, self.config.clone()), }, + // }, } } } pin_project_lite::pin_project! { - pub struct HttpServiceHandlerResponse + pub struct HttpServiceHandlerResponse where - T: AsyncRead, - T: AsyncWrite, - T: Unpin, - T: 'static, + F: Filter, + F: 'static, S: Service, S::Error: ResponseError, S::Error: 'static, @@ -557,52 +477,50 @@ pin_project_lite::pin_project! { X: Service, X::Error: ResponseError, X::Error: 'static, - U: Service, + U: Service, h1::Codec), Response = ()>, U::Error: fmt::Display, U::Error: error::Error, U::Error: 'static, { #[pin] - state: ResponseState, + state: ResponseState, } } pin_project_lite::pin_project! { #[project = StateProject] - enum ResponseState + enum ResponseState where S: Service, S::Error: ResponseError, S::Error: 'static, - T: AsyncRead, - T: AsyncWrite, - T: Unpin, - T: 'static, + F: Filter, + F: 'static, B: MessageBody, X: Service, X::Error: ResponseError, X::Error: 'static, - U: Service, + U: Service, h1::Codec), Response = ()>, U::Error: fmt::Display, U::Error: error::Error, U::Error: 'static, { - H1 { #[pin] fut: h1::Dispatcher }, - H2 { fut: Dispatcher }, - H2Handshake { data: - Option<( - Handshake, - Rc>, - Option>, - Option, - )>, - }, + H1 { #[pin] fut: h1::Dispatcher }, + // H2 { fut: Dispatcher }, + // H2Handshake { data: + // Option<( + // Handshake, + // Rc>, + // Option>, + // Option, + // )>, + // }, } } -impl Future for HttpServiceHandlerResponse +impl Future for HttpServiceHandlerResponse where - T: AsyncRead + AsyncWrite + Unpin + 'static, + F: Filter + 'static, S: Service, S::Error: ResponseError + 'static, S::Future: 'static, @@ -610,7 +528,7 @@ where B: MessageBody, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, h1::Codec), Response = ()>, U::Error: fmt::Display + error::Error + 'static, { type Output = Result<(), DispatchError>; @@ -620,26 +538,26 @@ where match this.state.project() { StateProject::H1 { fut } => fut.poll(cx), - StateProject::H2 { ref mut fut } => Pin::new(fut).poll(cx), - StateProject::H2Handshake { data } => { - let conn = if let Some(ref mut item) = data { - match Pin::new(&mut item.0).poll(cx) { - Poll::Ready(Ok(conn)) => conn, - Poll::Ready(Err(err)) => { - trace!("H2 handshake error: {}", err); - return Poll::Ready(Err(err.into())); - } - Poll::Pending => return Poll::Pending, - } - } else { - panic!() - }; - let (_, cfg, on_connect, peer_addr) = data.take().unwrap(); - self.as_mut().project().state.set(ResponseState::H2 { - fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr), - }); - self.poll(cx) - } + // StateProject::H2 { ref mut fut } => Pin::new(fut).poll(cx), + // StateProject::H2Handshake { data } => { + // let conn = if let Some(ref mut item) = data { + // match Pin::new(&mut item.0).poll(cx) { + // Poll::Ready(Ok(conn)) => conn, + // Poll::Ready(Err(err)) => { + // trace!("H2 handshake error: {}", err); + // return Poll::Ready(Err(err.into())); + // } + // Poll::Pending => return Poll::Pending, + // } + // } else { + // panic!() + // }; + // let (_, cfg, on_connect, peer_addr) = data.take().unwrap(); + // self.as_mut().project().state.set(ResponseState::H2 { + // fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr), + // }); + // self.poll(cx) + // } } } } diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index 4cf59cbf..81316898 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -5,12 +5,15 @@ use std::{convert::TryFrom, io, net, str::FromStr, sync::mpsc, thread}; use coo_kie::{Cookie, CookieJar}; use crate::codec::{AsyncRead, AsyncWrite, Framed}; +use crate::io::IoBoxed; use crate::rt::{net::TcpStream, System}; use crate::server::{Server, StreamServiceFactory}; use crate::{time::Millis, time::Seconds, util::Bytes}; use super::client::error::WsClientError; -use super::client::{Client, ClientRequest, ClientResponse, Connector}; +use super::client::{ + ws::WsConnection, Client, ClientRequest, ClientResponse, Connector, +}; use super::error::{HttpError, PayloadError}; use super::header::{HeaderMap, HeaderName, HeaderValue}; use super::payload::Payload; @@ -207,7 +210,7 @@ fn parts(parts: &mut Option) -> &mut Inner { /// assert!(response.status().is_success()); /// } /// ``` -pub fn server>(factory: F) -> TestServer { +pub fn server(factory: F) -> TestServer { let (tx, rx) = mpsc::channel(); // run server in separate thread @@ -318,21 +321,12 @@ impl TestServer { } /// Connect to websocket server at a given path - pub async fn ws_at( - &mut self, - path: &str, - ) -> Result, WsClientError> - { - let url = self.url(path); - let connect = self.client.ws(url).connect(); - connect.await.map(|ws| ws.into_inner().1) + pub async fn ws_at(&mut self, path: &str) -> Result { + self.client.ws(self.url(path)).connect().await } /// Connect to a websocket server - pub async fn ws( - &mut self, - ) -> Result, WsClientError> - { + pub async fn ws(&mut self) -> Result { self.ws_at("/").await } diff --git a/ntex/src/lib.rs b/ntex/src/lib.rs index f13c2721..c93d80c1 100644 --- a/ntex/src/lib.rs +++ b/ntex/src/lib.rs @@ -7,12 +7,12 @@ //! * `compress` - enables compression support in http and web modules //! * `cookie` - enables cookie support in http and web modules -#![warn( - rust_2018_idioms, - unreachable_pub, - // missing_debug_implementations, - // missing_docs, -)] +//#![warn( +// rust_2018_idioms, +// unreachable_pub, +// missing_debug_implementations, +// missing_docs, +//)] #![allow( type_alias_bounds, clippy::type_complexity, @@ -21,6 +21,7 @@ clippy::too_many_arguments, clippy::new_without_default )] +#![allow(unused_imports)] #[macro_use] extern crate log; @@ -35,7 +36,7 @@ pub(crate) use ntex_macros::rt_test2 as rt_test; pub mod channel; pub mod connect; -pub mod framed; +//pub mod framed; #[cfg(feature = "http-framework")] pub mod http; pub mod server; diff --git a/ntex/src/server/builder.rs b/ntex/src/server/builder.rs index 4e258b04..1e92a658 100644 --- a/ntex/src/server/builder.rs +++ b/ntex/src/server/builder.rs @@ -193,7 +193,7 @@ impl ServerBuilder { factory: F, ) -> io::Result where - F: StreamServiceFactory, + F: StreamServiceFactory, U: net::ToSocketAddrs, { let sockets = bind_addr(addr, self.backlog)?; @@ -219,7 +219,7 @@ impl ServerBuilder { /// Add new unix domain service to the server. pub fn bind_uds(self, name: N, addr: U, factory: F) -> io::Result where - F: StreamServiceFactory, + F: StreamServiceFactory, N: AsRef, U: AsRef, { @@ -249,7 +249,7 @@ impl ServerBuilder { factory: F, ) -> io::Result where - F: StreamServiceFactory, + F: StreamServiceFactory, { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; let token = self.token.next(); @@ -273,7 +273,7 @@ impl ServerBuilder { factory: F, ) -> io::Result where - F: StreamServiceFactory, + F: StreamServiceFactory, { let token = self.token.next(); self.services.push(Factory::create( diff --git a/ntex/src/server/config.rs b/ntex/src/server/config.rs index aab6940d..17b17ae7 100644 --- a/ntex/src/server/config.rs +++ b/ntex/src/server/config.rs @@ -6,8 +6,8 @@ use std::{ use log::error; use crate::rt::net::TcpStream; -use crate::service; use crate::util::{counter::CounterGuard, HashMap, Ready}; +use crate::{io::Io, service}; use super::builder::bind_addr; use super::service::{ @@ -199,7 +199,7 @@ impl InternalServiceFactory for ConfiguredService { res.push(( token, Box::new(StreamService::new(service::fn_service( - move |_: TcpStream| { + move |_: Io| { error!("Service {:?} is not configured", name); Ready::<_, ()>::Ok(()) }, @@ -292,7 +292,7 @@ impl ServiceRuntime { pub fn service(&self, name: &str, service: F) where F: service::IntoServiceFactory, - T: service::ServiceFactory + 'static, + T: service::ServiceFactory + 'static, T::Future: 'static, T::Service: 'static, T::InitError: fmt::Debug, @@ -338,7 +338,7 @@ struct ServiceFactory { impl service::ServiceFactory for ServiceFactory where - T: service::ServiceFactory, + T: service::ServiceFactory, T::Future: 'static, T::Service: 'static, T::Error: 'static, diff --git a/ntex/src/server/mod.rs b/ntex/src/server/mod.rs index 85b21695..d7edeba3 100644 --- a/ntex/src/server/mod.rs +++ b/ntex/src/server/mod.rs @@ -30,9 +30,6 @@ pub use self::config::{ServiceConfig, ServiceRuntime}; pub use self::service::StreamServiceFactory; pub use self::test::{build_test_server, test_server, TestServer}; -#[doc(hidden)] -pub use self::socket::FromStream; - #[non_exhaustive] #[derive(Copy, Clone, Debug, PartialEq, Eq)] /// Server readiness status diff --git a/ntex/src/server/openssl.rs b/ntex/src/server/openssl.rs index f314465c..507a47e4 100644 --- a/ntex/src/server/openssl.rs +++ b/ntex/src/server/openssl.rs @@ -1,10 +1,13 @@ use std::task::{Context, Poll}; use std::{error::Error, fmt, future::Future, io, marker::PhantomData, pin::Pin}; +pub use ntex_openssl::SslFilter; pub use open_ssl::ssl::{self, AlpnError, Ssl, SslAcceptor, SslAcceptorBuilder}; -pub use tokio_openssl::SslStream; + +use ntex_openssl::SslAcceptor as IoSslAcceptor; use crate::codec::{AsyncRead, AsyncWrite}; +use crate::io::{Filter, FilterFactory, Io}; use crate::service::{Service, ServiceFactory}; use crate::time::{sleep, Millis, Sleep}; use crate::util::{counter::Counter, counter::CounterGuard, Ready}; @@ -14,19 +17,17 @@ use super::MAX_SSL_ACCEPT_COUNTER; /// Support `TLS` server connections via openssl package /// /// `openssl` feature enables `Acceptor` type -pub struct Acceptor { - acceptor: SslAcceptor, - timeout: Millis, - io: PhantomData, +pub struct Acceptor { + acceptor: IoSslAcceptor, + _t: PhantomData, } -impl Acceptor { +impl Acceptor { /// Create default openssl acceptor service pub fn new(acceptor: SslAcceptor) -> Self { Acceptor { - acceptor, - timeout: Millis(5_000), - io: PhantomData, + acceptor: IoSslAcceptor::new(acceptor), + _t: PhantomData, } } @@ -34,30 +35,26 @@ impl Acceptor { /// /// Default is set to 5 seconds. pub fn timeout>(mut self, timeout: U) -> Self { - self.timeout = timeout.into(); + self.acceptor.timeout(timeout); self } } -impl Clone for Acceptor { +impl Clone for Acceptor { fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), - timeout: self.timeout, - io: PhantomData, + _t: PhantomData, } } } -impl ServiceFactory for Acceptor -where - T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, -{ - type Request = T; - type Response = SslStream; +impl ServiceFactory for Acceptor { + type Request = Io; + type Response = Io>; type Error = Box; type Config = (); - type Service = AcceptorService; + type Service = AcceptorService; type InitError = (); type Future = Ready; @@ -66,28 +63,23 @@ where Ready::Ok(AcceptorService { acceptor: self.acceptor.clone(), conns: conns.priv_clone(), - timeout: self.timeout, - io: PhantomData, + _t: PhantomData, }) }) } } -pub struct AcceptorService { - acceptor: SslAcceptor, +pub struct AcceptorService { + acceptor: IoSslAcceptor, conns: Counter, - timeout: Millis, - io: PhantomData, + _t: PhantomData, } -impl Service for AcceptorService -where - T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static, -{ - type Request = T; - type Response = SslStream; +impl Service for AcceptorService { + type Request = Io; + type Response = Io>; type Error = Box; - type Future = AcceptorServiceResponse; + type Future = AcceptorServiceResponse; #[inline] fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll> { @@ -100,57 +92,29 @@ where #[inline] fn call(&self, req: Self::Request) -> Self::Future { - let ssl = Ssl::new(self.acceptor.context()) - .expect("Provided SSL acceptor was invalid."); AcceptorServiceResponse { _guard: self.conns.get(), - io: None, - delay: self.timeout.map(sleep), - io_factory: Some(SslStream::new(ssl, req)), + fut: self.acceptor.clone().create(req), } } } -pub struct AcceptorServiceResponse -where - T: AsyncRead, - T: AsyncWrite, -{ - io: Option>, - delay: Option, - io_factory: Option, open_ssl::error::ErrorStack>>, - _guard: CounterGuard, -} - -impl Future for AcceptorServiceResponse { - type Output = Result, Box>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut(); - - if let Some(ref delay) = this.delay { - match delay.poll_elapsed(cx) { - Poll::Pending => (), - Poll::Ready(_) => { - return Poll::Ready(Err(Box::new(io::Error::new( - io::ErrorKind::TimedOut, - "ssl handshake timeout", - )))) - } - } - } - - match this.io_factory.take() { - Some(Ok(io)) => this.io = Some(io), - Some(Err(err)) => return Poll::Ready(Err(Box::new(err))), - None => (), - } - - let io = this.io.as_mut().unwrap(); - match Pin::new(io).poll_accept(cx) { - Poll::Ready(Ok(_)) => Poll::Ready(Ok(this.io.take().unwrap())), - Poll::Ready(Err(e)) => Poll::Ready(Err(Box::new(e))), - Poll::Pending => Poll::Pending, - } +pin_project_lite::pin_project! { + pub struct AcceptorServiceResponse + where + F: Filter, + F: 'static, + { + #[pin] + fut: >::Future, + _guard: CounterGuard, + } +} + +impl Future for AcceptorServiceResponse { + type Output = Result>, Box>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().fut.poll(cx) } } diff --git a/ntex/src/server/service.rs b/ntex/src/server/service.rs index 74ddf0e4..64cb05f1 100644 --- a/ntex/src/server/service.rs +++ b/ntex/src/server/service.rs @@ -1,3 +1,4 @@ +use std::convert::TryInto; use std::{ future::Future, marker::PhantomData, net::SocketAddr, pin::Pin, task::Context, task::Poll, @@ -5,12 +6,12 @@ use std::{ use log::error; +use crate::io::Io; use crate::service::{Service, ServiceFactory}; use crate::util::{counter::CounterGuard, Ready}; use crate::{rt::spawn, time::Millis}; -use super::socket::{FromStream, Stream}; -use super::Token; +use super::{socket::Stream, Token}; /// Server message pub(super) enum ServerMessage { @@ -22,8 +23,8 @@ pub(super) enum ServerMessage { ForceShutdown, } -pub trait StreamServiceFactory: Send + Clone + 'static { - type Factory: ServiceFactory; +pub trait StreamServiceFactory: Send + Clone + 'static { + type Factory: ServiceFactory; fn create(&self) -> Self::Factory; } @@ -57,12 +58,11 @@ impl StreamService { } } -impl Service for StreamService +impl Service for StreamService where - T: Service, + T: Service, T::Future: 'static, T::Error: 'static, - I: FromStream, { type Request = (Option, ServerMessage); type Response = (); @@ -82,7 +82,7 @@ where fn call(&self, (guard, req): (Option, ServerMessage)) -> Self::Future { match req { ServerMessage::Connect(stream) => { - let stream = FromStream::from_stream(stream).map_err(|e| { + let stream = stream.try_into().map_err(|e| { error!("Cannot convert to an async io stream: {}", e); }); @@ -102,18 +102,16 @@ where } } -pub(super) struct Factory, Io: FromStream> { +pub(super) struct Factory { name: String, inner: F, token: Token, addr: SocketAddr, - _t: PhantomData, } -impl Factory +impl Factory where - F: StreamServiceFactory, - Io: FromStream + Send + 'static, + F: StreamServiceFactory, { pub(crate) fn create( name: String, @@ -126,15 +124,13 @@ where token, inner, addr, - _t: PhantomData, }) } } -impl InternalServiceFactory for Factory +impl InternalServiceFactory for Factory where - F: StreamServiceFactory, - Io: FromStream + Send + 'static, + F: StreamServiceFactory, { fn name(&self, _: Token) -> &str { &self.name @@ -146,7 +142,6 @@ where inner: self.inner.clone(), token: self.token, addr: self.addr, - _t: PhantomData, }) } @@ -187,11 +182,10 @@ impl InternalServiceFactory for Box { } } -impl StreamServiceFactory for F +impl StreamServiceFactory for F where F: Fn() -> T + Send + Clone + 'static, - T: ServiceFactory, - I: FromStream, + T: ServiceFactory, { type Factory = T; diff --git a/ntex/src/server/socket.rs b/ntex/src/server/socket.rs index 76e253c1..6bc10be8 100644 --- a/ntex/src/server/socket.rs +++ b/ntex/src/server/socket.rs @@ -1,6 +1,7 @@ -use std::{fmt, io, net}; +use std::{convert::TryFrom, fmt, io, net}; use crate::codec::{AsyncRead, AsyncWrite}; +use crate::io::{Io, IoStream}; use crate::rt::net::TcpStream; pub(crate) enum Listener { @@ -146,32 +147,29 @@ pub enum Stream { Uds(mio::net::UnixStream), } -pub trait FromStream: AsyncRead + AsyncWrite + Sized { - fn from_stream(stream: Stream) -> io::Result; -} +impl TryFrom for Io { + type Error = io::Error; -#[cfg(unix)] -impl FromStream for TcpStream { - fn from_stream(sock: Stream) -> io::Result { + fn try_from(sock: Stream) -> Result { + #[cfg(unix)] match sock { Stream::Tcp(stream) => { use std::os::unix::io::{FromRawFd, IntoRawFd}; let fd = IntoRawFd::into_raw_fd(stream); let io = TcpStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) })?; io.set_nodelay(true)?; - Ok(io) + Ok(Io::new(io)) } - #[cfg(unix)] - Stream::Uds(_) => { - panic!("Should not happen, bug in server impl"); + Stream::Uds(stream) => { + use crate::rt::net::UnixStream; + use std::os::unix::io::{FromRawFd, IntoRawFd}; + let fd = IntoRawFd::into_raw_fd(stream); + let ud = UnixStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) }); + todo!() } } - } -} -#[cfg(windows)] -impl FromStream for TcpStream { - fn from_stream(sock: Stream) -> io::Result { + #[cfg(windows)] match sock { Stream::Tcp(stream) => { use std::os::windows::io::{FromRawSocket, IntoRawSocket}; @@ -179,26 +177,7 @@ impl FromStream for TcpStream { let io = TcpStream::from_std(unsafe { FromRawSocket::from_raw_socket(fd) })?; io.set_nodelay(true)?; - Ok(io) - } - #[cfg(unix)] - Stream::Uds(_) => { - panic!("Should not happen, bug in server impl"); - } - } - } -} - -#[cfg(unix)] -impl FromStream for crate::rt::net::UnixStream { - fn from_stream(sock: Stream) -> io::Result { - match sock { - Stream::Tcp(_) => panic!("Should not happen, bug in server impl"), - Stream::Uds(stream) => { - use crate::rt::net::UnixStream; - use std::os::unix::io::{FromRawFd, IntoRawFd}; - let fd = IntoRawFd::into_raw_fd(stream); - UnixStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) }) + Ok(Io::new(io)) } } } diff --git a/ntex/src/server/test.rs b/ntex/src/server/test.rs index 741b1f64..87eed26d 100644 --- a/ntex/src/server/test.rs +++ b/ntex/src/server/test.rs @@ -37,7 +37,7 @@ use crate::server::{Server, ServerBuilder, StreamServiceFactory}; /// assert!(response.status().is_success()); /// } /// ``` -pub fn test_server>(factory: F) -> TestServer { +pub fn test_server(factory: F) -> TestServer { let (tx, rx) = mpsc::channel(); // run server in separate thread diff --git a/ntex/src/web/httprequest.rs b/ntex/src/web/httprequest.rs index 5750b76e..c55a0307 100644 --- a/ntex/src/web/httprequest.rs +++ b/ntex/src/web/httprequest.rs @@ -3,6 +3,7 @@ use std::{cell::Ref, cell::RefCell, cell::RefMut, fmt, net, rc::Rc}; use crate::http::{ HeaderMap, HttpMessage, Message, Method, Payload, RequestHead, Uri, Version, }; +use crate::io::IoRef; use crate::router::Path; use crate::util::{Extensions, Ready}; @@ -105,6 +106,12 @@ impl HttpRequest { } } + /// Io reference for current connection + #[inline] + pub fn io(&self) -> Option<&IoRef> { + self.head().io.as_ref() + } + /// Get a reference to the Path parameters. /// /// Params is a container for url parameters. @@ -183,17 +190,6 @@ impl HttpRequest { &self.0.rmap } - /// Peer socket address - /// - /// Peer address is actual socket address, if proxy is used in front of - /// ntex http server, then peer address would be address of this proxy. - /// - /// To get client connection information `.connection_info()` should be used. - #[inline] - pub fn peer_addr(&self) -> Option { - self.head().peer_addr - } - /// Get *ConnectionInfo* for the current request. /// /// This method panics if request's extensions container is already diff --git a/ntex/src/web/info.rs b/ntex/src/web/info.rs index f1b9c682..c0af010f 100644 --- a/ntex/src/web/info.rs +++ b/ntex/src/web/info.rs @@ -119,7 +119,9 @@ impl ConnectionInfo { } if remote.is_none() { // get peeraddr from socketaddr - peer = req.peer_addr.map(|addr| format!("{}", addr)); + + // TODO! fix + // peer = req.peer_addr.map(|addr| format!("{}", addr)); } } diff --git a/ntex/src/web/request.rs b/ntex/src/web/request.rs index bdf985e7..a04b9f41 100644 --- a/ntex/src/web/request.rs +++ b/ntex/src/web/request.rs @@ -6,6 +6,7 @@ use std::{fmt, net}; use crate::http::{ header, HeaderMap, HttpMessage, Method, Payload, RequestHead, Response, Uri, Version, }; +use crate::io::IoRef; use crate::router::{Path, Resource}; use crate::util::Extensions; @@ -87,6 +88,12 @@ impl WebRequest { WebResponse::new(res.into(), self.req) } + /// Io reference for current connection + #[inline] + pub fn io(&self) -> Option<&IoRef> { + self.head().io.as_ref() + } + /// This method returns reference to the request head #[inline] pub fn head(&self) -> &RequestHead { @@ -147,17 +154,6 @@ impl WebRequest { } } - /// Peer socket address - /// - /// Peer address is actual socket address, if proxy is used in front of - /// ntex http server, then peer address would be address of this proxy. - /// - /// To get client connection information `ConnectionInfo` should be used. - #[inline] - pub fn peer_addr(&self) -> Option { - self.head().peer_addr - } - /// Get *ConnectionInfo* for the current request. #[inline] pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> { diff --git a/ntex/src/web/server.rs b/ntex/src/web/server.rs index 1527ee67..cc696fc8 100644 --- a/ntex/src/web/server.rs +++ b/ntex/src/web/server.rs @@ -2,8 +2,8 @@ use std::{fmt, io, marker::PhantomData, net, sync::Arc, sync::Mutex}; #[cfg(feature = "openssl")] use crate::server::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; -#[cfg(feature = "rustls")] -use crate::server::rustls::ServerConfig as RustlsServerConfig; +//#[cfg(feature = "rustls")] +//use crate::server::rustls::ServerConfig as RustlsServerConfig; #[cfg(unix)] use crate::http::Protocol; @@ -275,7 +275,6 @@ where .disconnect_timeout(c.client_disconnect) .memory_pool(c.pool) .finish(map_config(factory(), move |_| cfg.clone())) - .tcp() }, )?; Ok(self) @@ -326,50 +325,50 @@ where Ok(self) } - #[cfg(feature = "rustls")] - /// Use listener for accepting incoming tls connection requests - /// - /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn listen_rustls( - self, - lst: net::TcpListener, - config: RustlsServerConfig, - ) -> io::Result { - self.listen_rustls_inner(lst, config) - } + // #[cfg(feature = "rustls")] + // /// Use listener for accepting incoming tls connection requests + // /// + // /// This method sets alpn protocols to "h2" and "http/1.1" + // pub fn listen_rustls( + // self, + // lst: net::TcpListener, + // config: RustlsServerConfig, + // ) -> io::Result { + // self.listen_rustls_inner(lst, config) + // } - #[cfg(feature = "rustls")] - fn listen_rustls_inner( - mut self, - lst: net::TcpListener, - config: RustlsServerConfig, - ) -> io::Result { - let factory = self.factory.clone(); - let cfg = self.config.clone(); - let addr = lst.local_addr().unwrap(); + // #[cfg(feature = "rustls")] + // fn listen_rustls_inner( + // mut self, + // lst: net::TcpListener, + // config: RustlsServerConfig, + // ) -> io::Result { + // let factory = self.factory.clone(); + // let cfg = self.config.clone(); + // let addr = lst.local_addr().unwrap(); - self.builder = self.builder.listen( - format!("ntex-web-rustls-service-{}", addr), - lst, - move || { - let c = cfg.lock().unwrap(); - let cfg = AppConfig::new( - true, - addr, - c.host.clone().unwrap_or_else(|| format!("{}", addr)), - ); - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .disconnect_timeout(c.client_disconnect) - .ssl_handshake_timeout(c.handshake_timeout) - .memory_pool(c.pool) - .finish(map_config(factory(), move |_| cfg.clone())) - .rustls(config.clone()) - }, - )?; - Ok(self) - } + // self.builder = self.builder.listen( + // format!("ntex-web-rustls-service-{}", addr), + // lst, + // move || { + // let c = cfg.lock().unwrap(); + // let cfg = AppConfig::new( + // true, + // addr, + // c.host.clone().unwrap_or_else(|| format!("{}", addr)), + // ); + // HttpService::build() + // .keep_alive(c.keep_alive) + // .client_timeout(c.client_timeout) + // .disconnect_timeout(c.client_disconnect) + // .ssl_handshake_timeout(c.handshake_timeout) + // .memory_pool(c.pool) + // .finish(map_config(factory(), move |_| cfg.clone())) + // .rustls(config.clone()) + // }, + // )?; + // Ok(self) + // } /// The socket address to bind /// @@ -437,21 +436,21 @@ where Ok(self) } - #[cfg(feature = "rustls")] - /// Start listening for incoming tls connections. - /// - /// This method sets alpn protocols to "h2" and "http/1.1" - pub fn bind_rustls( - mut self, - addr: A, - config: RustlsServerConfig, - ) -> io::Result { - let sockets = self.bind2(addr)?; - for lst in sockets { - self = self.listen_rustls_inner(lst, config.clone())?; - } - Ok(self) - } + // #[cfg(feature = "rustls")] + // /// Start listening for incoming tls connections. + // /// + // /// This method sets alpn protocols to "h2" and "http/1.1" + // pub fn bind_rustls( + // mut self, + // addr: A, + // config: RustlsServerConfig, + // ) -> io::Result { + // let sockets = self.bind2(addr)?; + // for lst in sockets { + // self = self.listen_rustls_inner(lst, config.clone())?; + // } + // Ok(self) + // } #[cfg(unix)] /// Start listening for unix domain connections on existing listener. @@ -479,16 +478,11 @@ where socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); - pipeline_factory(|io: UnixStream| { - crate::util::Ready::Ok((io, Protocol::Http1, None)) - }) - .and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .memory_pool(c.pool) - .finish(map_config(factory(), move |_| config.clone())), - ) + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .memory_pool(c.pool) + .finish(map_config(factory(), move |_| config.clone())) })?; Ok(self) } @@ -520,16 +514,11 @@ where socket_addr, c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), ); - pipeline_factory(|io: UnixStream| { - crate::util::Ready::Ok((io, Protocol::Http1, None)) - }) - .and_then( - HttpService::build() - .keep_alive(c.keep_alive) - .client_timeout(c.client_timeout) - .memory_pool(c.pool) - .finish(map_config(factory(), move |_| config.clone())), - ) + HttpService::build() + .keep_alive(c.keep_alive) + .client_timeout(c.client_timeout) + .memory_pool(c.pool) + .finish(map_config(factory(), move |_| config.clone())) }, )?; Ok(self) diff --git a/ntex/src/web/test.rs b/ntex/src/web/test.rs index 3c27f120..3fb462b3 100644 --- a/ntex/src/web/test.rs +++ b/ntex/src/web/test.rs @@ -465,15 +465,12 @@ impl TestRequest { /// Complete request creation and generate `Request` instance pub fn to_request(mut self) -> Request { - let mut req = self.req.finish(); - req.head_mut().peer_addr = self.peer_addr; - req + self.req.finish() } /// Complete request creation and generate `WebRequest` instance pub fn to_srv_request(mut self) -> WebRequest { - let (mut head, payload) = self.req.finish().into_parts(); - head.peer_addr = self.peer_addr; + let (head, payload) = self.req.finish().into_parts(); *self.path.get_mut() = head.uri.clone(); WebRequest::new(HttpRequest::new( @@ -494,8 +491,7 @@ impl TestRequest { /// Complete request creation and generate `HttpRequest` instance pub fn to_http_request(mut self) -> HttpRequest { - let (mut head, payload) = self.req.finish().into_parts(); - head.peer_addr = self.peer_addr; + let (head, payload) = self.req.finish().into_parts(); *self.path.get_mut() = head.uri.clone(); HttpRequest::new( @@ -511,8 +507,7 @@ impl TestRequest { /// Complete request creation and generate `HttpRequest` and `Payload` instances pub fn to_http_parts(mut self) -> (HttpRequest, Payload) { - let (mut head, payload) = self.req.finish().into_parts(); - head.peer_addr = self.peer_addr; + let (head, payload) = self.req.finish().into_parts(); *self.path.get_mut() = head.uri.clone(); let req = HttpRequest::new( @@ -636,7 +631,6 @@ where HttpService::build() .client_timeout(ctimeout) .h1(map_config(factory(), move |_| cfg.clone())) - .tcp() }), HttpVer::Http2 => builder.listen("test", tcp, move || { let cfg = @@ -644,7 +638,6 @@ where HttpService::build() .client_timeout(ctimeout) .h2(map_config(factory(), move |_| cfg.clone())) - .tcp() }), HttpVer::Both => builder.listen("test", tcp, move || { let cfg = @@ -652,7 +645,6 @@ where HttpService::build() .client_timeout(ctimeout) .finish(map_config(factory(), move |_| cfg.clone())) - .tcp() }), }, #[cfg(feature = "openssl")] @@ -842,8 +834,9 @@ impl TestServerConfig { /// Start rustls server #[cfg(feature = "rustls")] pub fn rustls(mut self, config: rust_tls::ServerConfig) -> Self { - self.stream = StreamType::Rustls(config); - self + // self.stream = StreamType::Rustls(config); + // self + unimplemented!() } /// Set server client timeout in seconds for first request. @@ -928,19 +921,12 @@ impl TestServer { } /// Connect to websocket server at a given path - pub async fn ws_at( - &self, - path: &str, - ) -> Result, WsClientError> { - let url = self.url(path); - let connect = self.client.ws(url).connect(); - connect.await + pub async fn ws_at(&self, path: &str) -> Result { + self.client.ws(self.url(path)).connect().await } /// Connect to a websocket server - pub async fn ws( - &self, - ) -> Result, WsClientError> { + pub async fn ws(&self) -> Result { self.ws_at("/").await } diff --git a/ntex/src/ws/mod.rs b/ntex/src/ws/mod.rs index 567ad8a3..cd66fe18 100644 --- a/ntex/src/ws/mod.rs +++ b/ntex/src/ws/mod.rs @@ -25,6 +25,7 @@ pub use self::stream::{StreamDecoder, StreamEncoder}; pub enum WsError { Service(E), KeepAlive, + Disconnected, Protocol(ProtocolError), Io(io::Error), } diff --git a/ntex/src/ws/sink.rs b/ntex/src/ws/sink.rs index bb833687..bedb3833 100644 --- a/ntex/src/ws/sink.rs +++ b/ntex/src/ws/sink.rs @@ -1,18 +1,18 @@ use std::{future::Future, rc::Rc}; -use crate::framed::{OnDisconnect, State}; +use crate::io::{Io, IoRef, OnDisconnect}; use crate::ws; pub struct WsSink(Rc); struct WsSinkInner { - state: State, + io: IoRef, codec: ws::Codec, } impl WsSink { - pub(crate) fn new(state: State, codec: ws::Codec) -> Self { - Self(Rc::new(WsSinkInner { state, codec })) + pub(crate) fn new(io: IoRef, codec: ws::Codec) -> Self { + Self(Rc::new(WsSinkInner { io, codec })) } /// Endcode and send message to the peer. @@ -23,13 +23,13 @@ impl WsSink { let inner = self.0.clone(); async move { - inner.state.write().encode(item, &inner.codec)?; + inner.io.write().encode(item, &inner.codec)?; Ok(()) } } /// Notify when connection get disconnected pub fn on_disconnect(&self) -> OnDisconnect { - self.0.state.on_disconnect() + self.0.io.on_disconnect() } } diff --git a/ntex/tests/http_awc_client.rs b/ntex/tests/http_awc_client.rs index a08829db..cbf1ec4e 100644 --- a/ntex/tests/http_awc_client.rs +++ b/ntex/tests/http_awc_client.rs @@ -215,15 +215,12 @@ async fn test_connection_reuse() { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) - .and_then( - HttpService::new(map_config( - App::new().service( - web::resource("/").route(web::to(|| async { HttpResponse::Ok() })), - ), - |_| AppConfig::default(), - )) - .tcp(), - ) + .and_then(HttpService::new(map_config( + App::new().service( + web::resource("/").route(web::to(|| async { HttpResponse::Ok() })), + ), + |_| AppConfig::default(), + ))) }); let client = Client::build().timeout(Seconds(10)).finish(); @@ -253,15 +250,12 @@ async fn test_connection_force_close() { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) - .and_then( - HttpService::new(map_config( - App::new().service( - web::resource("/").route(web::to(|| async { HttpResponse::Ok() })), - ), - |_| AppConfig::default(), - )) - .tcp(), - ) + .and_then(HttpService::new(map_config( + App::new().service( + web::resource("/").route(web::to(|| async { HttpResponse::Ok() })), + ), + |_| AppConfig::default(), + ))) }); let client = Client::build().timeout(Seconds(10)).finish(); @@ -291,15 +285,12 @@ async fn test_connection_server_close() { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) - .and_then( - HttpService::new(map_config( - App::new().service(web::resource("/").route(web::to(|| async { - HttpResponse::Ok().force_close().finish() - }))), - |_| AppConfig::default(), - )) - .tcp(), - ) + .and_then(HttpService::new(map_config( + App::new().service(web::resource("/").route(web::to(|| async { + HttpResponse::Ok().force_close().finish() + }))), + |_| AppConfig::default(), + ))) }); let client = Client::build().timeout(Seconds(10)).finish(); @@ -329,16 +320,13 @@ async fn test_connection_wait_queue() { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) - .and_then( - HttpService::new(map_config( - App::new().service( - web::resource("/") - .route(web::to(|| async { HttpResponse::Ok().body(STR) })), - ), - |_| AppConfig::default(), - )) - .tcp(), - ) + .and_then(HttpService::new(map_config( + App::new().service( + web::resource("/") + .route(web::to(|| async { HttpResponse::Ok().body(STR) })), + ), + |_| AppConfig::default(), + ))) }); let client = Client::build() @@ -378,15 +366,12 @@ async fn test_connection_wait_queue_force_close() { num2.fetch_add(1, Ordering::Relaxed); ok(io) }) - .and_then( - HttpService::new(map_config( - App::new().service(web::resource("/").route(web::to(|| async { - HttpResponse::Ok().force_close().body(STR) - }))), - |_| AppConfig::default(), - )) - .tcp(), - ) + .and_then(HttpService::new(map_config( + App::new().service(web::resource("/").route(web::to(|| async { + HttpResponse::Ok().force_close().body(STR) + }))), + |_| AppConfig::default(), + ))) }); let client = Client::build() diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 95d22b5f..330c6fd0 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -8,8 +8,8 @@ use ntex::http::test::server as test_server; use ntex::http::{ body, header, HttpService, KeepAlive, Method, Request, Response, StatusCode, }; -use ntex::time::{sleep, Millis}; -use ntex::{service::fn_service, time::Seconds, util::Bytes, web::error}; +use ntex::time::{sleep, Millis, Seconds}; +use ntex::{service::fn_service, util::Bytes, util::Ready, web::error}; #[ntex::test] async fn test_h1() { @@ -22,7 +22,6 @@ async fn test_h1() { assert!(req.peer_addr().is_some()); future::ok::<_, io::Error>(Response::Ok().finish()) }) - .tcp() }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -41,7 +40,6 @@ async fn test_h1_2() { assert_eq!(req.version(), http::Version::HTTP_11); future::ok::<_, io::Error>(Response::Ok().finish()) }) - .tcp() }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -72,7 +70,6 @@ async fn test_expect_continue() { let _ = req.payload().next().await; Ok::<_, io::Error>(Response::Ok().finish()) })) - .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -101,22 +98,18 @@ async fn test_chunked_payload() { let total_size: usize = chunk_sizes.iter().sum(); let srv = test_server(|| { - HttpService::build() - .h1(fn_service(|mut request: Request| { - request - .take_payload() - .map(|res| match res { - Ok(pl) => pl, - Err(e) => panic!("Error reading payload: {}", e), - }) - .fold(0usize, |acc, chunk| ready(acc + chunk.len())) - .map(|req_size| { - Ok::<_, io::Error>( - Response::Ok().body(format!("size={}", req_size)), - ) - }) - })) - .tcp() + HttpService::build().h1(fn_service(|mut request: Request| { + request + .take_payload() + .map(|res| match res { + Ok(pl) => pl, + Err(e) => panic!("Error reading payload: {}", e), + }) + .fold(0usize, |acc, chunk| ready(acc + chunk.len())) + .map(|req_size| { + Ok::<_, io::Error>(Response::Ok().body(format!("size={}", req_size))) + }) + })) }); let returned_size = { @@ -156,7 +149,6 @@ async fn test_slow_request() { HttpService::build() .client_timeout(Seconds(1)) .finish(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -169,9 +161,7 @@ async fn test_slow_request() { #[ntex::test] async fn test_http1_malformed_request() { let srv = test_server(|| { - HttpService::build() - .h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() + HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -184,9 +174,7 @@ async fn test_http1_malformed_request() { #[ntex::test] async fn test_http1_keepalive() { let srv = test_server(|| { - HttpService::build() - .h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() + HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -207,7 +195,6 @@ async fn test_http1_keepalive_timeout() { HttpService::build() .keep_alive(1) .h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -225,9 +212,7 @@ async fn test_http1_keepalive_timeout() { #[ntex::test] async fn test_http1_keepalive_close() { let srv = test_server(|| { - HttpService::build() - .h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() + HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -245,9 +230,7 @@ async fn test_http1_keepalive_close() { #[ntex::test] async fn test_http10_keepalive_default_close() { let srv = test_server(|| { - HttpService::build() - .h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() + HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -264,9 +247,7 @@ async fn test_http10_keepalive_default_close() { #[ntex::test] async fn test_http10_keepalive() { let srv = test_server(|| { - HttpService::build() - .h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() + HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -293,7 +274,6 @@ async fn test_http1_keepalive_disabled() { HttpService::build() .keep_alive(KeepAlive::Disabled) .h1(|_| future::ok::<_, io::Error>(Response::Ok().finish())) - .tcp() }); let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); @@ -315,20 +295,18 @@ async fn test_content_length() { }; let srv = test_server(|| { - HttpService::build() - .h1(|req: Request| { - let indx: usize = req.uri().path()[1..].parse().unwrap(); - let statuses = [ - StatusCode::NO_CONTENT, - StatusCode::CONTINUE, - StatusCode::SWITCHING_PROTOCOLS, - StatusCode::PROCESSING, - StatusCode::OK, - StatusCode::NOT_FOUND, - ]; - future::ok::<_, io::Error>(Response::new(statuses[indx])) - }) - .tcp() + HttpService::build().h1(|req: Request| { + let indx: usize = req.uri().path()[1..].parse().unwrap(); + let statuses = [ + StatusCode::NO_CONTENT, + StatusCode::CONTINUE, + StatusCode::SWITCHING_PROTOCOLS, + StatusCode::PROCESSING, + StatusCode::OK, + StatusCode::NOT_FOUND, + ]; + future::ok::<_, io::Error>(Response::new(statuses[indx])) + }) }); let header = HeaderName::from_static("content-length"); @@ -362,7 +340,7 @@ async fn test_h1_headers() { let data = data.clone(); HttpService::build().h1(move |_| { let mut builder = Response::Ok(); - for idx in 0..90 { + for idx in 0..20 { builder.header( format!("X-TEST-{}", idx).as_str(), "TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \ @@ -380,8 +358,9 @@ async fn test_h1_headers() { TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ", ); } - future::ok::<_, io::Error>(builder.body(data.clone())) - }).tcp() + println!("SENDING body"); + Ready::Ok::<_, io::Error>(builder.body(data.clone())) + }) }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -417,9 +396,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ #[ntex::test] async fn test_h1_body() { let mut srv = test_server(|| { - HttpService::build() - .h1(|_| ok::<_, io::Error>(Response::Ok().body(STR))) - .tcp() + HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR))) }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -433,9 +410,7 @@ async fn test_h1_body() { #[ntex::test] async fn test_h1_head_empty() { let mut srv = test_server(|| { - HttpService::build() - .h1(|_| ok::<_, io::Error>(Response::Ok().body(STR))) - .tcp() + HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR))) }); let response = srv.request(http::Method::HEAD, "/").send().await.unwrap(); @@ -457,13 +432,9 @@ async fn test_h1_head_empty() { #[ntex::test] async fn test_h1_head_binary() { let mut srv = test_server(|| { - HttpService::build() - .h1(|_| { - ok::<_, io::Error>( - Response::Ok().content_length(STR.len() as u64).body(STR), - ) - }) - .tcp() + HttpService::build().h1(|_| { + ok::<_, io::Error>(Response::Ok().content_length(STR.len() as u64).body(STR)) + }) }); let response = srv.request(http::Method::HEAD, "/").send().await.unwrap(); @@ -485,9 +456,7 @@ async fn test_h1_head_binary() { #[ntex::test] async fn test_h1_head_binary2() { let srv = test_server(|| { - HttpService::build() - .h1(|_| ok::<_, io::Error>(Response::Ok().body(STR))) - .tcp() + HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR))) }); let response = srv.request(http::Method::HEAD, "/").send().await.unwrap(); @@ -505,14 +474,12 @@ async fn test_h1_head_binary2() { #[ntex::test] async fn test_h1_body_length() { let mut srv = test_server(|| { - HttpService::build() - .h1(|_| { - let body = once(ok(Bytes::from_static(STR.as_ref()))); - ok::<_, io::Error>( - Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), - ) - }) - .tcp() + HttpService::build().h1(|_| { + let body = once(ok(Bytes::from_static(STR.as_ref()))); + ok::<_, io::Error>( + Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)), + ) + }) }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -525,18 +492,15 @@ async fn test_h1_body_length() { #[ntex::test] async fn test_h1_body_chunked_explicit() { - env_logger::init(); let mut srv = test_server(|| { - HttpService::build() - .h1(|_| { - let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, io::Error>( - Response::Ok() - .header(header::TRANSFER_ENCODING, "chunked") - .streaming(body), - ) - }) - .tcp() + HttpService::build().h1(|_| { + let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, io::Error>( + Response::Ok() + .header(header::TRANSFER_ENCODING, "chunked") + .streaming(body), + ) + }) }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -561,12 +525,10 @@ async fn test_h1_body_chunked_explicit() { #[ntex::test] async fn test_h1_body_chunked_implicit() { let mut srv = test_server(|| { - HttpService::build() - .h1(|_| { - let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref()))); - ok::<_, io::Error>(Response::Ok().streaming(body)) - }) - .tcp() + HttpService::build().h1(|_| { + let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref()))); + ok::<_, io::Error>(Response::Ok().streaming(body)) + }) }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -589,16 +551,14 @@ async fn test_h1_body_chunked_implicit() { #[ntex::test] async fn test_h1_response_http_error_handling() { let mut srv = test_server(|| { - HttpService::build() - .h1(fn_service(|_| { - let broken_header = Bytes::from_static(b"\0\0\0"); - ok::<_, io::Error>( - Response::Ok() - .header(http::header::CONTENT_TYPE, &broken_header[..]) - .body(STR), - ) - })) - .tcp() + HttpService::build().h1(fn_service(|_| { + let broken_header = Bytes::from_static(b"\0\0\0"); + ok::<_, io::Error>( + Response::Ok() + .header(http::header::CONTENT_TYPE, &broken_header[..]) + .body(STR), + ) + })) }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -612,14 +572,12 @@ async fn test_h1_response_http_error_handling() { #[ntex::test] async fn test_h1_service_error() { let mut srv = test_server(|| { - HttpService::build() - .h1(|_| { - future::err::(error::InternalError::default( - "error", - StatusCode::BAD_REQUEST, - )) - }) - .tcp() + HttpService::build().h1(|_| { + future::err::(error::InternalError::default( + "error", + StatusCode::BAD_REQUEST, + )) + }) }); let response = srv.request(Method::GET, "/").send().await.unwrap(); @@ -629,19 +587,3 @@ async fn test_h1_service_error() { let bytes = srv.load_body(response).await.unwrap(); assert_eq!(bytes, Bytes::from_static(b"error")); } - -#[ntex::test] -async fn test_h1_on_connect() { - let srv = test_server(|| { - HttpService::build() - .on_connect(|_| 10usize) - .h1(|req: Request| { - assert!(req.extensions().contains::()); - future::ok::<_, io::Error>(Response::Ok().finish()) - }) - .tcp() - }); - - let response = srv.request(Method::GET, "/").send().await.unwrap(); - assert!(response.status().is_success()); -} diff --git a/ntex/tests/server.rs b/ntex/tests/server.rs index 52980397..2332caaf 100644 --- a/ntex/tests/server.rs +++ b/ntex/tests/server.rs @@ -3,10 +3,9 @@ use std::sync::{mpsc, Arc}; use std::{io, io::Read, net, thread, time}; use futures::future::{lazy, ok, FutureExt}; -use futures::SinkExt; -use ntex::codec::{BytesCodec, Framed}; -use ntex::rt::net::TcpStream; +use ntex::codec::BytesCodec; +use ntex::io::Io; use ntex::server::{Server, TestServer}; use ntex::service::fn_service; use ntex::util::{Bytes, Ready}; @@ -77,9 +76,10 @@ fn test_start() { .backlog(100) .disable_signals() .bind("test", addr, move || { - fn_service(|io: TcpStream| async move { - let mut f = Framed::new(io, BytesCodec); - f.send(Bytes::from_static(b"test")).await.unwrap(); + fn_service(|io: Io| async move { + io.send(Bytes::from_static(b"test"), &BytesCodec) + .await + .unwrap(); Ok::<_, ()>(()) }) })