diff --git a/Cargo.toml b/Cargo.toml index a90cbbcf..d641c490 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "ntex-bytes", "ntex-codec", "ntex-io", + "ntex-openssl", "ntex-router", "ntex-rt", "ntex-service", @@ -16,6 +17,7 @@ ntex = { path = "ntex" } ntex-bytes = { path = "ntex-bytes" } ntex-codec = { path = "ntex-codec" } ntex-io = { path = "ntex-io" } +ntex-openssl = { path = "ntex-openssl" } ntex-router = { path = "ntex-router" } ntex-rt = { path = "ntex-rt" } ntex-service = { path = "ntex-service" } diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index 51d7ec33..a66efa2e 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -37,3 +37,4 @@ tok-io = { version = "1", package = "tokio", default-features = false, features ntex = "0.4.13" futures = "0.3.13" rand = "0.8" +env_logger = "0.9" \ No newline at end of file diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs index 31f3b9e2..66fd9d74 100644 --- a/ntex-io/src/dispatcher.rs +++ b/ntex-io/src/dispatcher.rs @@ -1,6 +1,6 @@ //! Framed transport dispatcher use std::{ - cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time, + cell::Cell, future::Future, io, pin::Pin, rc::Rc, task::Context, task::Poll, time, }; use ntex_bytes::Pool; @@ -70,6 +70,7 @@ enum DispatcherError { KeepAlive, Encoder(U), Service(S), + Io(io::Error), } enum PollService { @@ -171,10 +172,19 @@ where { fn handle_result(&self, item: Result, write: WriteRef<'_>) { self.inflight.set(self.inflight.get() - 1); - match write.encode_result(item, &self.codec) { - Ok(true) => (), - Ok(false) => write.enable_backpressure(None), - Err(err) => self.error.set(Some(err.into())), + match item { + Ok(Some(val)) => match write.encode(val, &self.codec) { + Ok(true) => (), + Ok(false) => write.enable_backpressure(None), + Err(Either::Left(err)) => { + self.error.set(Some(DispatcherError::Encoder(err))) + } + Err(Either::Right(err)) => { + self.error.set(Some(DispatcherError::Io(err))) + } + }, + Err(err) => self.error.set(Some(DispatcherError::Service(err))), + Ok(None) => return, } write.wake_dispatcher(); } @@ -217,7 +227,10 @@ where match slf.st.get() { DispatcherState::Processing => { let result = match slf.poll_service(this.service, cx, read) { - Poll::Pending => return Poll::Pending, + Poll::Pending => { + let _ = read.poll_ready(cx); + return Poll::Pending; + } Poll::Ready(result) => result, }; @@ -237,8 +250,20 @@ where } Ok(None) => { log::trace!("not enough data to decode next frame, register dispatch task"); - read.wake(cx); - return Poll::Pending; + // service is ready, wake io read task + match read.poll_ready(cx) { + Poll::Pending + | Poll::Ready(Ok(Some(()))) => { + read.resume(); + return Poll::Pending; + } + Poll::Ready(Ok(None)) => { + DispatchItem::Disconnect(None) + } + Poll::Ready(Err(err)) => { + DispatchItem::Disconnect(Some(err)) + } + } } Err(err) => { slf.st.set(DispatcherState::Stop); @@ -248,8 +273,18 @@ where } } else { // no new events - state.register_dispatcher(cx); - return Poll::Pending; + match read.poll_ready(cx) { + Poll::Pending | Poll::Ready(Ok(Some(()))) => { + read.resume(); + return Poll::Pending; + } + Poll::Ready(Ok(None)) => { + DispatchItem::Disconnect(None) + } + Poll::Ready(Err(err)) => { + DispatchItem::Disconnect(Some(err)) + } + } } } PollService::Item(item) => item, @@ -318,7 +353,7 @@ where if slf.shared.inflight.get() == 0 { slf.st.set(DispatcherState::Shutdown); - state.shutdown(cx); + state.init_shutdown(cx); } else { state.register_dispatcher(cx); return Poll::Pending; @@ -368,15 +403,19 @@ where item: Result::Item>, S::Error>, write: WriteRef<'_>, ) { - match write.encode_result(item, &self.shared.codec) { - Ok(true) => (), - Ok(false) => write.enable_backpressure(None), - Err(Either::Left(err)) => { - self.error.set(Some(err)); - } - Err(Either::Right(err)) => { - self.shared.error.set(Some(DispatcherError::Encoder(err))) - } + match item { + Ok(Some(item)) => match write.encode(item, &self.shared.codec) { + Ok(true) => (), + Ok(false) => write.enable_backpressure(None), + Err(Either::Left(err)) => { + self.shared.error.set(Some(DispatcherError::Encoder(err))) + } + Err(Either::Right(err)) => { + self.shared.error.set(Some(DispatcherError::Io(err))) + } + }, + Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))), + Ok(None) => (), } } @@ -388,9 +427,6 @@ where ) -> Poll> { match srv.poll_ready(cx) { Poll::Ready(Ok(_)) => { - // service is ready, wake io read task - read.resume(); - // check keepalive timeout self.check_keepalive(); @@ -407,6 +443,9 @@ where DispatcherError::Encoder(err) => { PollService::Item(DispatchItem::EncoderError(err)) } + DispatcherError::Io(err) => { + PollService::Item(DispatchItem::Disconnect(Some(err))) + } DispatcherError::Service(err) => { self.error.set(Some(err)); PollService::ServiceError @@ -425,7 +464,7 @@ where // get io error if let Some(err) = self.state.take_error() { - PollService::Item(DispatchItem::IoError(err)) + PollService::Item(DispatchItem::Disconnect(Some(err))) } else { PollService::ServiceError } @@ -803,15 +842,15 @@ mod tests { // response message assert!(!state.write().is_ready()); - assert_eq!(state.write().with_buf(|buf| buf.len()), 65536); + assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 65536); client.remote_buffer_cap(10240); sleep(Millis(50)).await; - assert_eq!(state.write().with_buf(|buf| buf.len()), 55296); + assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 55296); client.remote_buffer_cap(45056); sleep(Millis(50)).await; - assert_eq!(state.write().with_buf(|buf| buf.len()), 10240); + assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 10240); // backpressure disabled assert!(state.write().is_ready()); @@ -821,7 +860,6 @@ mod tests { #[ntex::test] async fn test_keepalive() { let (client, server) = IoTest::create(); - // do not allow to write to socket client.remote_buffer_cap(1024); client.write("GET /test HTTP/1\r\n\r\n"); @@ -854,8 +892,7 @@ mod tests { .keepalive_timeout(Seconds(1)) .await; }); - - state.0.disconnect_timeout.set(Seconds(1)); + state.0.disconnect_timeout.set(Millis::ONE_SEC); let buf = client.read().await.unwrap(); assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index a17b38d9..961dcfc8 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -2,7 +2,7 @@ use std::{io, rc::Rc, task::Context, task::Poll}; use ntex_bytes::BytesMut; -use super::state::{Flags, IoStateInner}; +use super::state::{Flags, IoRef, IoStateInner}; use super::{Filter, ReadFilter, WriteFilter, WriteReadiness}; pub struct DefaultFilter(Rc); @@ -13,7 +13,20 @@ impl DefaultFilter { } } -impl Filter for DefaultFilter {} +impl Filter for DefaultFilter { + #[inline] + fn shutdown(&self, _: &IoRef) -> Poll> { + let mut flags = self.0.flags.get(); + if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { + flags.insert(Flags::IO_SHUTDOWN); + self.0.flags.set(flags); + self.0.read_task.wake(); + self.0.write_task.wake(); + } + + Poll::Ready(Ok(())) + } +} impl ReadFilter for DefaultFilter { #[inline] @@ -48,20 +61,20 @@ impl ReadFilter for DefaultFilter { } #[inline] - fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) { - if new_bytes > 0 { - if buf.len() > self.0.pool.get().read_params().high as usize { - log::trace!( - "buffer is too large {}, enable read back-pressure", - buf.len() - ); - self.0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL); - } else { - self.0.insert_flags(Flags::RD_READY); - } - self.0.dispatch_task.wake(); + fn release_read_buf( + &self, + buf: BytesMut, + new_bytes: usize, + ) -> Result<(), io::Error> { + if new_bytes > 0 && buf.len() > self.0.pool.get().read_params().high as usize { + log::trace!( + "buffer is too large {}, enable read back-pressure", + buf.len() + ); + self.0.insert_flags(Flags::RD_BUF_FULL); } self.0.read_buf.set(Some(buf)); + Ok(()) } } @@ -71,12 +84,23 @@ impl WriteFilter for DefaultFilter { &self, cx: &mut Context<'_>, ) -> Poll> { - let flags = self.0.flags.get(); + let mut flags = self.0.flags.get(); if flags.contains(Flags::IO_ERR) { Poll::Ready(Err(WriteReadiness::Terminate)) } else if flags.intersects(Flags::IO_SHUTDOWN) { - Poll::Ready(Err(WriteReadiness::Shutdown)) + Poll::Ready(Err(WriteReadiness::Shutdown( + self.0.disconnect_timeout.get(), + ))) + } else if flags.contains(Flags::IO_FILTERS) + && !flags.contains(Flags::IO_FILTERS_TO) + { + flags.insert(Flags::IO_FILTERS_TO); + self.0.flags.set(flags); + self.0.write_task.register(cx.waker()); + Poll::Ready(Err(WriteReadiness::Timeout( + self.0.disconnect_timeout.get(), + ))) } else { self.0.write_task.register(cx.waker()); Poll::Ready(Ok(())) @@ -100,13 +124,15 @@ impl WriteFilter for DefaultFilter { } #[inline] - fn release_write_buf(&self, buf: BytesMut) { + fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> { let pool = self.0.pool.get(); if buf.is_empty() { pool.release_write_buf(buf); } else { self.0.write_buf.set(Some(buf)); + self.0.write_task.wake(); } + Ok(()) } } @@ -120,7 +146,11 @@ impl NullFilter { } } -impl Filter for NullFilter {} +impl Filter for NullFilter { + fn shutdown(&self, _: &IoRef) -> Poll> { + Poll::Ready(Ok(())) + } +} impl ReadFilter for NullFilter { fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll> { @@ -133,7 +163,9 @@ impl ReadFilter for NullFilter { None } - fn release_read_buf(&self, _: BytesMut, _: usize) {} + fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> { + Ok(()) + } } impl WriteFilter for NullFilter { @@ -147,5 +179,7 @@ impl WriteFilter for NullFilter { None } - fn release_write_buf(&self, _: BytesMut) {} + fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> { + Ok(()) + } } diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index 99d337d2..b25a04af 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -14,19 +14,22 @@ mod tokio_impl; use ntex_bytes::BytesMut; use ntex_codec::{Decoder, Encoder}; +use ntex_util::time::Millis; pub use self::dispatcher::Dispatcher; +pub use self::filter::DefaultFilter; pub use self::state::{Io, IoRef, ReadRef, WriteRef}; pub use self::tasks::{ReadState, WriteState}; pub use self::time::Timer; -pub use self::utils::{from_iostream, into_boxed}; +pub use self::utils::{filter_factory, from_iostream, into_boxed, into_io}; pub type IoBoxed = Io>; #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum WriteReadiness { - Shutdown, + Timeout(Millis), + Shutdown(Millis), Terminate, } @@ -37,7 +40,8 @@ pub trait ReadFilter { fn get_read_buf(&self) -> Option; - fn release_read_buf(&self, buf: BytesMut, new_bytes: usize); + fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) + -> Result<(), io::Error>; } pub trait WriteFilter { @@ -48,10 +52,12 @@ pub trait WriteFilter { fn get_write_buf(&self) -> Option; - fn release_write_buf(&self, buf: BytesMut); + fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>; } -pub trait Filter: ReadFilter + WriteFilter {} +pub trait Filter: ReadFilter + WriteFilter { + fn shutdown(&self, st: &IoRef) -> Poll>; +} pub trait FilterFactory: Sized { type Filter: Filter; @@ -59,7 +65,7 @@ pub trait FilterFactory: Sized { type Error: fmt::Debug; type Future: Future, Self::Error>>; - fn create(&self, st: Io) -> Self::Future; + fn create(self, st: Io) -> Self::Future; } pub trait IoStream { @@ -79,8 +85,8 @@ pub enum DispatchItem { DecoderError(::Error), /// Encoder parse error EncoderError(::Error), - /// Unexpected io error - IoError(io::Error), + /// Socket is disconnected + Disconnect(Option), } impl fmt::Debug for DispatchItem @@ -108,8 +114,8 @@ where DispatchItem::DecoderError(ref e) => { write!(fmt, "DispatchItem::DecoderError({:?})", e) } - DispatchItem::IoError(ref e) => { - write!(fmt, "DispatchItem::IoError({:?})", e) + DispatchItem::Disconnect(ref e) => { + write!(fmt, "DispatchItem::Disconnect({:?})", e) } } } @@ -128,8 +134,8 @@ mod tests { assert!(format!("{:?}", err).contains("DispatchItem::Encoder")); let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err")); assert!(format!("{:?}", err).contains("DispatchItem::Decoder")); - let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err")); - assert!(format!("{:?}", err).contains("DispatchItem::IoError")); + let err = T::Disconnect(Some(io::Error::new(io::ErrorKind::Other, "err"))); + assert!(format!("{:?}", err).contains("DispatchItem::Disconnect")); assert!(format!("{:?}", T::WBackPressureEnabled) .contains("DispatchItem::WBackPressureEnabled")); diff --git a/ntex-io/src/state.rs b/ntex-io/src/state.rs index 251cf7a5..8e049789 100644 --- a/ntex-io/src/state.rs +++ b/ntex-io/src/state.rs @@ -4,7 +4,8 @@ use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc}; use ntex_bytes::{BytesMut, PoolId, PoolRef}; use ntex_codec::{Decoder, Encoder}; -use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Seconds}; +use ntex_util::time::{Millis, Seconds}; +use ntex_util::{future::poll_fn, future::Either, task::LocalWaker}; use super::filter::{DefaultFilter, NullFilter}; use super::tasks::{ReadState, WriteState}; @@ -14,8 +15,12 @@ bitflags::bitflags! { pub struct Flags: u16 { /// io error occured const IO_ERR = 0b0000_0000_0000_0001; + /// shuting down filters + const IO_FILTERS = 0b0000_0000_0000_0010; + /// shuting down filters timeout + const IO_FILTERS_TO = 0b0000_0000_0000_0100; /// shutdown io tasks - const IO_SHUTDOWN = 0b0000_0000_0000_0100; + const IO_SHUTDOWN = 0b0000_0000_0000_1000; /// pause io read const RD_PAUSED = 0b0000_0000_0000_1000; @@ -51,7 +56,7 @@ pub struct IoRef(pub(super) Rc); pub(crate) struct IoStateInner { pub(super) flags: Cell, pub(super) pool: Cell, - pub(super) disconnect_timeout: Cell, + pub(super) disconnect_timeout: Cell, pub(super) error: Cell>, pub(super) read_task: LocalWaker, pub(super) write_task: LocalWaker, @@ -77,6 +82,16 @@ impl IoStateInner { self.flags.set(flags); } + #[inline] + pub(super) fn notify_keepalive(&self) { + let mut flags = self.flags.get(); + if !flags.contains(Flags::DSP_KEEPALIVE) { + flags.insert(Flags::DSP_KEEPALIVE); + self.flags.set(flags); + self.dispatch_task.wake(); + } + } + #[inline] pub(super) fn notify_disconnect(&self) { let mut on_disconnect = self.on_disconnect.borrow_mut(); @@ -86,6 +101,36 @@ impl IoStateInner { } } } + + #[inline] + fn is_io_err(&self) -> bool { + self.flags.get().contains(Flags::IO_ERR) + } + + #[inline] + pub(super) fn shutdown_filters(&self, st: &IoRef) -> Result<(), io::Error> { + let mut flags = self.flags.get(); + if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { + let result = match self.filter.get().shutdown(st) { + Poll::Pending => return Ok(()), + Poll::Ready(Ok(())) => { + flags.insert(Flags::IO_SHUTDOWN); + Ok(()) + } + Poll::Ready(Err(err)) => { + flags.insert(Flags::IO_ERR); + self.dispatch_task.wake(); + Err(err) + } + }; + self.flags.set(flags); + self.read_task.wake(); + self.write_task.wake(); + result + } else { + Ok(()) + } + } } impl Eq for IoStateInner {} @@ -130,7 +175,7 @@ impl Io { pool: Cell::new(pool), flags: Cell::new(Flags::empty()), error: Cell::new(None), - disconnect_timeout: Cell::new(Seconds(1)), + disconnect_timeout: Cell::new(Millis::ONE_SEC), dispatch_task: LocalWaker::new(), read_task: LocalWaker::new(), write_task: LocalWaker::new(), @@ -147,10 +192,12 @@ impl Io { }; inner.filter.replace(filter_ref); - // start io tasks - io.start(ReadState(inner.clone()), WriteState(inner.clone())); + let io_ref = IoRef(inner); - Io(IoRef(inner), FilterItem::Ptr(Box::into_raw(filter))) + // start io tasks + io.start(ReadState(io_ref.clone()), WriteState(io_ref.clone())); + + Io(io_ref, FilterItem::Ptr(Box::into_raw(filter))) } } @@ -172,7 +219,7 @@ impl Io { #[inline] /// Set io disconnect timeout in secs pub fn set_disconnect_timeout(&self, timeout: Seconds) { - self.0 .0.disconnect_timeout.set(timeout); + self.0 .0.disconnect_timeout.set(timeout.into()); } } @@ -201,6 +248,25 @@ impl Io { pub fn dispatcher_stopped(&self) { self.0 .0.insert_flags(Flags::DSP_STOP); } + + #[inline] + /// Gracefully shutdown read and write io tasks + pub fn init_shutdown(&self, cx: &mut Context<'_>) { + let flags = self.0 .0.flags.get(); + + if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) { + log::trace!("initiate io shutdown {:?}", flags); + self.0 .0.insert_flags(Flags::IO_FILTERS); + if let Err(err) = self.0 .0.shutdown_filters(&self.0) { + self.0 .0.error.set(Some(err)); + self.0 .0.insert_flags(Flags::IO_ERR); + } + + self.0 .0.read_task.wake(); + self.0 .0.write_task.wake(); + self.0 .0.dispatch_task.register(cx.waker()); + } + } } impl IoRef { @@ -220,7 +286,7 @@ impl IoRef { #[inline] /// Check if io error occured in read or write task pub fn is_io_err(&self) -> bool { - self.0.flags.get().contains(Flags::IO_ERR) + self.0.is_io_err() } #[inline] @@ -244,20 +310,6 @@ impl IoRef { .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP) } - #[inline] - /// Gracefully shutdown read and write io tasks - pub fn shutdown(&self, cx: &mut Context<'_>) { - let flags = self.0.flags.get(); - - if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { - log::trace!("initiate io shutdown {:?}", flags); - self.0.insert_flags(Flags::IO_SHUTDOWN); - self.0.read_task.wake(); - self.0.write_task.wake(); - self.0.dispatch_task.register(cx.waker()); - } - } - #[inline] /// Take io error if any occured pub fn take_error(&self) -> Option { @@ -331,32 +383,21 @@ impl Io { }; self.0 .0.read_buf.set(buf); - let result = match item { + return match item { Ok(Some(el)) => Ok(Some(el)), Ok(None) => { self.0 .0.remove_flags(Flags::RD_READY); - poll_fn(|cx| { - if read.is_ready() { - Poll::Ready(()) - } else { - read.wake(cx); - Poll::Pending - } - }) - .await; - if self.is_io_err() { - if let Some(err) = self.take_error() { - Err(Either::Right(err)) - } else { - Ok(None) - } - } else { - continue; + if poll_fn(|cx| read.poll_ready(cx)) + .await + .map_err(Either::Right)? + .is_none() + { + return Ok(None); } + continue; } Err(err) => Err(Either::Left(err)), }; - return result; } } @@ -364,8 +405,8 @@ impl Io { /// Encode item, send to a peer pub async fn send( &self, - codec: &U, item: U::Item, + codec: &U, ) -> Result<(), Either> where U: Encoder, @@ -374,31 +415,46 @@ impl Io { let mut buf = filter .get_write_buf() .unwrap_or_else(|| self.0 .0.pool.get().get_write_buf()); + let is_write_sleep = buf.is_empty(); codec.encode(item, &mut buf).map_err(Either::Left)?; - filter.release_write_buf(buf); + filter.release_write_buf(buf).map_err(Either::Right)?; self.0 .0.insert_flags(Flags::WR_WAIT); if is_write_sleep { self.0 .0.write_task.wake(); } - poll_fn(|cx| { - if !self.0 .0.flags.get().contains(Flags::WR_WAIT) || self.is_io_err() { - Poll::Ready(()) - } else { - self.register_dispatcher(cx); - Poll::Pending - } - }) - .await; + poll_fn(|cx| self.write().poll_flush(cx)) + .await + .map_err(Either::Right)?; + Ok(()) + } - if self.is_io_err() { - let err = self.0 .0.error.take().unwrap_or_else(|| { - io::Error::new(io::ErrorKind::Other, "Internal error") - }); - Err(Either::Right(err)) - } else { + #[inline] + /// Shuts down connection + pub async fn shutdown(&self) -> Result<(), io::Error> { + if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { Ok(()) + } else { + poll_fn(|cx| { + let flags = self.flags(); + if !flags.contains(Flags::IO_FILTERS) { + self.init_shutdown(cx); + } + + if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { + if let Some(err) = self.0 .0.error.take() { + Poll::Ready(Err(err)) + } else { + Poll::Ready(Ok(())) + } + } else { + self.0 .0.insert_flags(Flags::IO_FILTERS); + self.0 .0.dispatch_task.register(cx.waker()); + Poll::Pending + } + }) + .await } } @@ -412,26 +468,52 @@ impl Io { where U: Decoder, { - let mut buf = self.0 .0.read_buf.take(); - let item = if let Some(ref mut buf) = buf { - codec.decode(buf) - } else { - Ok(None) - }; - self.0 .0.read_buf.set(buf); + if self + .read() + .poll_ready(cx) + .map_err(Either::Right)? + .is_ready() + { + let mut buf = self.0 .0.read_buf.take(); + let item = if let Some(ref mut buf) = buf { + codec.decode(buf) + } else { + Ok(None) + }; + self.0 .0.read_buf.set(buf); - match item { - Ok(Some(el)) => Poll::Ready(Ok(Some(el))), - Ok(None) => { - self.read().wake(cx); - Poll::Pending + match item { + Ok(Some(el)) => Poll::Ready(Ok(Some(el))), + Ok(None) => { + if let Poll::Ready(res) = + self.read().poll_ready(cx).map_err(Either::Right)? + { + if res.is_none() { + return Poll::Ready(Ok(None)); + } + } + Poll::Pending + } + Err(err) => Poll::Ready(Err(Either::Left(err))), } - Err(err) => Poll::Ready(Err(Either::Left(err))), + } else { + Poll::Pending } } } impl Io { + #[inline] + /// Get referece to filter + pub fn filter(&self) -> &F { + if let FilterItem::Ptr(p) = self.1 { + if let Some(r) = unsafe { p.as_ref() } { + return r; + } + } + panic!() + } + #[inline] pub fn into_boxed(mut self) -> crate::IoBoxed where @@ -457,18 +539,18 @@ impl Io { } #[inline] - pub async fn add_filter(self, factory: &T) -> Result, T::Error> + pub fn add_filter(self, factory: T) -> T::Future where T: FilterFactory, { - factory.create(self).await + factory.create(self) } #[inline] - pub fn map_filter(mut self, map: T) -> Io + pub fn map_filter(mut self, map: U) -> Result, T::Error> where - T: FnOnce(F) -> U, - U: Filter, + T: FilterFactory, + U: FnOnce(F) -> Result, { // replace current filter let filter = unsafe { @@ -477,7 +559,7 @@ impl Io { FilterItem::Boxed(_) => panic!(), FilterItem::Ptr(p) => { assert!(!p.is_null()); - Box::new(map(*Box::from_raw(p))) + Box::new(map(*Box::from_raw(p))?) } }; let filter_ref: &'static dyn Filter = { @@ -488,7 +570,7 @@ impl Io { filter }; - Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter))) + Ok(Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter)))) } } @@ -562,7 +644,7 @@ impl<'a> WriteRef<'a> { #[inline] /// Get mut access to write buffer - pub fn with_buf(&self, f: F) -> R + pub fn with_buf(&self, f: F) -> Result where F: FnOnce(&mut BytesMut) -> R, { @@ -575,8 +657,8 @@ impl<'a> WriteRef<'a> { } let result = f(&mut buf); - filter.release_write_buf(buf); - result + filter.release_write_buf(buf)?; + Ok(result) } #[inline] @@ -587,7 +669,7 @@ impl<'a> WriteRef<'a> { &self, item: U::Item, codec: &U, - ) -> Result::Error> + ) -> Result::Error, io::Error>> where U: Encoder, { @@ -608,70 +690,45 @@ impl<'a> WriteRef<'a> { } // encode item and wake write task - let result = codec.encode(item, &mut buf).map(|_| { - if is_write_sleep { - self.0.write_task.wake(); - } - buf.len() < hw - }); - filter.release_write_buf(buf); - result + let result = codec + .encode(item, &mut buf) + .map(|_| { + if is_write_sleep { + self.0.write_task.wake(); + } + buf.len() < hw + }) + .map_err(Either::Left); + filter.release_write_buf(buf).map_err(Either::Right)?; + Ok(result?) } else { Ok(true) } } #[inline] - /// Write item to a buf and wake up io task - pub fn encode_result( - &self, - item: Result, E>, - codec: &U, - ) -> Result> - where - U: Encoder, - { - let flags = self.0.flags.get(); + /// Wake write task and instruct to write all data. + /// + /// When write task is done wake dispatcher. + pub fn poll_flush(&self, cx: &mut Context<'_>) -> Poll> { + self.0.insert_flags(Flags::WR_WAIT); - if !flags.intersects(Flags::IO_ERR | Flags::DSP_ERR) { - match item { - Ok(Some(item)) => { - let filter = self.0.filter.get(); - let mut buf = filter - .get_write_buf() - .unwrap_or_else(|| self.0.pool.get().get_write_buf()); - let is_write_sleep = buf.is_empty(); - let (hw, lw) = self.0.pool.get().write_params().unpack(); - - // make sure we've got room - let remaining = buf.capacity() - buf.len(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - // encode item - if let Err(err) = codec.encode(item, &mut buf) { - log::trace!("Encoder error: {:?}", err); - filter.release_write_buf(buf); - self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR); - self.0.dispatch_task.wake(); - return Err(Either::Right(err)); - } else if is_write_sleep { - self.0.write_task.wake(); - } - let result = Ok(buf.len() < hw); - filter.release_write_buf(buf); - result - } - Err(err) => { - self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR); - self.0.dispatch_task.wake(); - Err(Either::Left(err)) - } - _ => Ok(true), + if let Some(buf) = self.0.write_buf.take() { + if !buf.is_empty() { + self.0.write_buf.set(Some(buf)); + self.0.write_task.wake(); + self.0.dispatch_task.register(cx.waker()); + return Poll::Pending; } + } + + if self.0.is_io_err() { + Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| { + io::Error::new(io::ErrorKind::Other, "disconnected") + }))) } else { - Ok(true) + self.0.dispatch_task.register(cx.waker()); + Poll::Ready(Ok(())) } } } @@ -720,28 +777,6 @@ impl<'a> ReadRef<'a> { } } - #[inline] - /// Wake read task and instruct to read more data - /// - /// Only wakes if back-pressure is enabled on read task - /// otherwise read is already awake. - pub fn wake(&self, cx: &mut Context<'_>) { - let mut flags = self.0.flags.get(); - flags.remove(Flags::RD_READY); - if flags.contains(Flags::RD_BUF_FULL) { - log::trace!("read back-pressure is enabled, wake io task"); - flags.remove(Flags::RD_BUF_FULL); - self.0.read_task.wake(); - } - if flags.contains(Flags::RD_PAUSED) { - log::trace!("read is paused, wake io task"); - flags.remove(Flags::RD_PAUSED); - self.0.read_task.wake(); - } - self.0.flags.set(flags); - self.0.dispatch_task.register(cx.waker()); - } - #[inline] /// Attempts to decode a frame from the read buffer. pub fn decode( @@ -753,7 +788,11 @@ impl<'a> ReadRef<'a> { { let mut buf = self.0.read_buf.take(); let result = if let Some(ref mut buf) = buf { - codec.decode(buf) + let result = codec.decode(buf); + if result.as_ref().map(|v| v.is_none()).unwrap_or(false) { + self.0.remove_flags(Flags::RD_READY); + } + result } else { self.0.remove_flags(Flags::RD_READY); Ok(None) @@ -775,12 +814,46 @@ impl<'a> ReadRef<'a> { .unwrap_or_else(|| self.0.pool.get().get_read_buf()); let res = f(&mut buf); if buf.is_empty() { + self.0.remove_flags(Flags::RD_READY); self.0.pool.get().release_read_buf(buf); } else { self.0.read_buf.set(Some(buf)); } res } + + #[inline] + /// Wake read task and instruct to read more data + /// + /// Only wakes if back-pressure is enabled on read task + /// otherwise read is already awake. + pub fn poll_ready( + &self, + cx: &mut Context<'_>, + ) -> Poll, io::Error>> { + let mut flags = self.0.flags.get(); + let ready = flags.contains(Flags::RD_READY); + + if self.0.is_io_err() { + if let Some(err) = self.0.error.take() { + Poll::Ready(Err(err)) + } else { + Poll::Ready(Ok(None)) + } + } else if ready { + Poll::Ready(Ok(Some(()))) + } else { + flags.remove(Flags::RD_READY); + if flags.contains(Flags::RD_BUF_FULL) { + log::trace!("read back-pressure is enabled, wake io task"); + flags.remove(Flags::RD_BUF_FULL); + self.0.read_task.wake(); + } + self.0.flags.set(flags); + self.0.dispatch_task.register(cx.waker()); + Poll::Pending + } + } } /// OnDisconnect future resolves when socket get disconnected @@ -866,6 +939,7 @@ mod tests { #[ntex::test] async fn utils() { + env_logger::init(); let (client, server) = IoTest::create(); client.remote_buffer_cap(1024); client.write(TEXT); @@ -907,14 +981,14 @@ mod tests { client.remote_buffer_cap(1024); let state = Io::new(server); state - .send(&BytesCodec, Bytes::from_static(b"test")) + .send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); let buf = client.read().await.unwrap(); assert_eq!(buf, Bytes::from_static(b"test")); client.write_error(io::Error::new(io::ErrorKind::Other, "err")); - let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await; + let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await; assert!(res.is_err()); assert!(state.flags().contains(Flags::IO_ERR)); assert!(state.flags().contains(Flags::DSP_STOP)); @@ -967,7 +1041,11 @@ mod tests { in_bytes: Rc>, out_bytes: Rc>, } - impl Filter for Counter {} + impl Filter for Counter { + fn shutdown(&self, _: &IoRef) -> Poll> { + Poll::Ready(Ok(())) + } + } impl ReadFilter for Counter { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { @@ -982,9 +1060,13 @@ mod tests { self.inner.get_read_buf() } - fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) { + fn release_read_buf( + &self, + buf: BytesMut, + new_bytes: usize, + ) -> Result<(), io::Error> { self.in_bytes.set(self.in_bytes.get() + new_bytes); - self.inner.release_read_buf(buf, new_bytes); + self.inner.release_read_buf(buf, new_bytes) } } @@ -1009,9 +1091,9 @@ mod tests { } } - fn release_write_buf(&self, buf: BytesMut) { + fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> { self.out_bytes.set(self.out_bytes.get() + buf.len()); - self.inner.release_write_buf(buf); + self.inner.release_write_buf(buf) } } @@ -1023,14 +1105,19 @@ mod tests { type Error = (); type Future = Ready>, Self::Error>; - fn create(&self, st: Io) -> Self::Future { + fn create(self, io: Io) -> Self::Future { let in_bytes = self.0.clone(); let out_bytes = self.1.clone(); - Ready::Ok(st.map_filter(|inner| Counter { - inner, - in_bytes, - out_bytes, - })) + Ready::Ok( + io.map_filter::(|inner| { + Ok(Counter { + inner, + in_bytes, + out_bytes, + }) + }) + .unwrap(), + ) } } @@ -1041,7 +1128,7 @@ mod tests { let factory = CounterFactory(in_bytes.clone(), out_bytes.clone()); let (client, server) = IoTest::create(); - let state = Io::new(server).add_filter(&factory).await.unwrap(); + let state = Io::new(server).add_filter(factory).await.unwrap(); client.remote_buffer_cap(1024); client.write(TEXT); @@ -1049,7 +1136,7 @@ mod tests { assert_eq!(msg, Bytes::from_static(BIN)); state - .send(&BytesCodec, Bytes::from_static(b"test")) + .send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); let buf = client.read().await.unwrap(); @@ -1066,10 +1153,10 @@ mod tests { let (client, server) = IoTest::create(); let state = Io::new(server) - .add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone())) + .add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone())) .await .unwrap() - .add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone())) + .add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone())) .await .unwrap(); let state = state.into_boxed(); @@ -1080,7 +1167,7 @@ mod tests { assert_eq!(msg, Bytes::from_static(BIN)); state - .send(&BytesCodec, Bytes::from_static(b"test")) + .send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); let buf = client.read().await.unwrap(); diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 8795ab8d..4a956a3f 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -1,98 +1,115 @@ -use std::{io, rc::Rc, task::Context, task::Poll}; +use std::{io, task::Context, task::Poll}; use ntex_bytes::{BytesMut, PoolRef}; -use ntex_util::time::Seconds; -use super::{state::Flags, state::IoStateInner, WriteReadiness}; +use super::{state::Flags, IoRef, WriteReadiness}; -pub struct ReadState(pub(super) Rc); +pub struct ReadState(pub(super) IoRef); impl ReadState { #[inline] pub fn memory_pool(&self) -> PoolRef { - self.0.pool.get() + self.0 .0.pool.get() } #[inline] pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.0.filter.get().poll_read_ready(cx) + self.0 .0.filter.get().poll_read_ready(cx) } #[inline] pub fn close(&self, err: Option) { - self.0.filter.get().read_closed(err); + self.0 .0.filter.get().read_closed(err); } #[inline] pub fn get_read_buf(&self) -> BytesMut { self.0 + .0 .filter .get() .get_read_buf() - .unwrap_or_else(|| self.0.pool.get().get_read_buf()) + .unwrap_or_else(|| self.0 .0.pool.get().get_read_buf()) } #[inline] - pub fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) { + pub fn release_read_buf( + &self, + buf: BytesMut, + new_bytes: usize, + ) -> Result<(), io::Error> { if buf.is_empty() { - self.0.pool.get().release_read_buf(buf); + self.0 .0.pool.get().release_read_buf(buf); + Ok(()) } else { - self.0.filter.get().release_read_buf(buf, new_bytes); + let mut flags = self.0 .0.flags.get(); + + // notify dispatcher + if new_bytes > 0 { + flags.insert(Flags::RD_READY); + self.0 .0.flags.set(flags); + self.0 .0.dispatch_task.wake(); + } + self.0 .0.filter.get().release_read_buf(buf, new_bytes)?; + + if flags.contains(Flags::IO_FILTERS) { + self.0 .0.shutdown_filters(&self.0)?; + } + Ok(()) } } } -pub struct WriteState(pub(super) Rc); +pub struct WriteState(pub(super) IoRef); impl WriteState { #[inline] pub fn memory_pool(&self) -> PoolRef { - self.0.pool.get() - } - - #[inline] - pub fn disconnect_timeout(&self) -> Seconds { - self.0.disconnect_timeout.get() + self.0 .0.pool.get() } #[inline] pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.0.filter.get().poll_write_ready(cx) + self.0 .0.filter.get().poll_write_ready(cx) } #[inline] pub fn close(&self, err: Option) { - self.0.filter.get().write_closed(err) + self.0 .0.filter.get().write_closed(err) } #[inline] pub fn get_write_buf(&self) -> Option { - self.0.write_buf.take() + self.0 .0.write_buf.take() } #[inline] - pub fn release_write_buf(&self, buf: BytesMut) { - let pool = self.0.pool.get(); + pub fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> { + let pool = self.0 .0.pool.get(); + let mut flags = self.0 .0.flags.get(); + if buf.is_empty() { pool.release_write_buf(buf); - - let mut flags = self.0.flags.get(); if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) { flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE); - self.0.flags.set(flags); - self.0.dispatch_task.wake(); + self.0 .0.flags.set(flags); + self.0 .0.dispatch_task.wake(); } } else { // if write buffer is smaller than high watermark value, turn off back-pressure - if buf.len() < pool.write_params_high() << 1 { - let mut flags = self.0.flags.get(); - if flags.contains(Flags::WR_BACKPRESSURE) { - flags.remove(Flags::WR_BACKPRESSURE); - self.0.flags.set(flags); - self.0.dispatch_task.wake(); - } + if buf.len() < pool.write_params_high() << 1 + && flags.contains(Flags::WR_BACKPRESSURE) + { + flags.remove(Flags::WR_BACKPRESSURE); + self.0 .0.flags.set(flags); + self.0 .0.dispatch_task.wake(); } - self.0.write_buf.set(Some(buf)) + self.0 .0.write_buf.set(Some(buf)) } + + if flags.contains(Flags::IO_FILTERS) { + self.0 .0.shutdown_filters(&self.0)?; + } + Ok(()) } } diff --git a/ntex-io/src/time.rs b/ntex-io/src/time.rs index 6509c52b..de6e6d65 100644 --- a/ntex-io/src/time.rs +++ b/ntex-io/src/time.rs @@ -5,7 +5,7 @@ use std::{ use ntex_util::spawn; use ntex_util::time::{now, sleep, Millis}; -use super::state::{Flags, IoRef, IoStateInner}; +use super::state::{IoRef, IoStateInner}; pub struct Timer(Rc>); @@ -79,8 +79,7 @@ impl Timer { let key = *key; if key <= now_time { for st in i.notifications.remove(&key).unwrap() { - st.dispatch_task.wake(); - st.insert_flags(Flags::DSP_KEEPALIVE); + st.notify_keepalive(); } } else { break; diff --git a/ntex-io/src/tokio_impl.rs b/ntex-io/src/tokio_impl.rs index 884e1a96..81d829ac 100644 --- a/ntex-io/src/tokio_impl.rs +++ b/ntex-io/src/tokio_impl.rs @@ -69,8 +69,13 @@ where Poll::Ready(Ok(n)) => { if n == 0 { log::trace!("io stream is disconnected"); - this.state.release_read_buf(buf, new_bytes); - this.state.close(None); + if let Err(e) = + this.state.release_read_buf(buf, new_bytes) + { + this.state.close(Some(e)); + } else { + this.state.close(None); + } return Poll::Ready(()); } else { new_bytes += n; @@ -81,15 +86,19 @@ where } Poll::Ready(Err(err)) => { log::trace!("read task failed on io {:?}", err); - this.state.release_read_buf(buf, new_bytes); + let _ = this.state.release_read_buf(buf, new_bytes); this.state.close(Some(err)); return Poll::Ready(()); } } } - this.state.release_read_buf(buf, new_bytes); - Poll::Pending + if let Err(e) = this.state.release_read_buf(buf, new_bytes) { + this.state.close(Some(e)); + Poll::Ready(()) + } else { + Poll::Pending + } } Poll::Pending => Poll::Pending, } @@ -98,8 +107,8 @@ where #[derive(Debug)] enum IoWriteState { - Processing, - Shutdown(Option, Shutdown), + Processing(Option), + Shutdown(Sleep, Shutdown), } #[derive(Debug)] @@ -125,7 +134,7 @@ where Self { io, state, - st: IoWriteState::Processing, + st: IoWriteState::Processing(None), } } } @@ -140,22 +149,41 @@ where let mut this = self.as_mut().get_mut(); match this.st { - IoWriteState::Processing => { + IoWriteState::Processing(ref mut delay) => { match this.state.poll_ready(cx) { Poll::Ready(Ok(())) => { + if let Some(delay) = delay { + if delay.poll_elapsed(cx).is_ready() { + this.state.close(Some(io::Error::new( + io::ErrorKind::TimedOut, + "Operation timedout", + ))); + return Poll::Ready(()); + } + } + // flush framed instance match flush_io(&mut *this.io.borrow_mut(), &this.state, cx) { Poll::Pending | Poll::Ready(true) => Poll::Pending, Poll::Ready(false) => Poll::Ready(()), } } - Poll::Ready(Err(WriteReadiness::Shutdown)) => { + Poll::Ready(Err(WriteReadiness::Timeout(time))) => { + if delay.is_none() { + *delay = Some(sleep(time)); + } + self.poll(cx) + } + Poll::Ready(Err(WriteReadiness::Shutdown(time))) => { log::trace!("write task is instructed to shutdown"); - this.st = IoWriteState::Shutdown( - this.state.disconnect_timeout().map(sleep), - Shutdown::None, - ); + let timeout = if let Some(delay) = delay.take() { + delay + } else { + sleep(time) + }; + + this.st = IoWriteState::Shutdown(timeout, Shutdown::None); self.poll(cx) } Poll::Ready(Err(WriteReadiness::Terminate)) => { @@ -229,10 +257,8 @@ where } // disconnect timeout - if let Some(ref delay) = delay { - if delay.poll_elapsed(cx).is_pending() { - return Poll::Pending; - } + if delay.poll_elapsed(cx).is_pending() { + return Poll::Pending; } log::trace!("write task is stopped after delay"); this.state.close(None); @@ -290,11 +316,17 @@ pub(super) fn flush_io( // remove written data let result = if written == len { buf.clear(); - state.release_write_buf(buf); + if let Err(e) = state.release_write_buf(buf) { + state.close(Some(e)); + return Poll::Ready(false); + } Poll::Ready(true) } else { buf.advance(written); - state.release_write_buf(buf); + if let Err(e) = state.release_write_buf(buf) { + state.close(Some(e)); + return Poll::Ready(false); + } Poll::Pending }; diff --git a/ntex-io/src/utils.rs b/ntex-io/src/utils.rs index c1c321d5..be7eea44 100644 --- a/ntex-io/src/utils.rs +++ b/ntex-io/src/utils.rs @@ -1,6 +1,9 @@ -use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory}; +use std::{io, marker::PhantomData, task::Context, task::Poll}; -use super::{Filter, Io, IoBoxed, IoStream}; +use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory}; +use ntex_util::future::Ready; + +use super::{Filter, FilterFactory, Io, IoBoxed, IoStream}; /// Service that converts any Io stream to IoBoxed stream pub fn into_boxed( @@ -47,3 +50,81 @@ where } }) } + +/// Service that converts IoStream stream to Io stream +pub fn into_io() -> impl ServiceFactory< + Config = (), + Request = I, + Response = Io, + Error = io::Error, + InitError = (), +> +where + I: IoStream, +{ + fn_factory_with_config(move |_: ()| { + Ready::Ok(into_service(move |io| Ready::Ok(Io::new(io)))) + }) +} + +/// Create filter factory service +pub fn filter_factory(filter: T) -> FilterServiceFactory +where + T: FilterFactory + Clone, + F: Filter, +{ + FilterServiceFactory { + filter, + _t: PhantomData, + } +} + +pub struct FilterServiceFactory { + filter: T, + _t: PhantomData, +} + +impl ServiceFactory for FilterServiceFactory +where + T: FilterFactory + Clone, + F: Filter, +{ + type Config = (); + type Request = Io; + type Response = Io; + type Error = T::Error; + type Service = FilterService; + type InitError = (); + type Future = Ready; + + fn new_service(&self, _: ()) -> Self::Future { + Ready::Ok(FilterService { + filter: self.filter.clone(), + _t: PhantomData, + }) + } +} + +pub struct FilterService { + filter: T, + _t: PhantomData, +} + +impl Service for FilterService +where + T: FilterFactory + Clone, + F: Filter, +{ + type Request = Io; + type Response = Io; + type Error = T::Error; + type Future = T::Future; + + fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&self, req: Io) -> Self::Future { + req.add_filter(self.filter.clone()) + } +} diff --git a/ntex-openssl/Cargo.toml b/ntex-openssl/Cargo.toml new file mode 100644 index 00000000..2c11be8c --- /dev/null +++ b/ntex-openssl/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "ntex-openssl" +version = "0.1.0" +authors = ["ntex contributors "] +description = "An implementation of SSL streams for ntex backed by OpenSSL" +keywords = ["network", "framework", "async", "futures"] +homepage = "https://ntex.rs" +repository = "https://github.com/ntex-rs/ntex.git" +documentation = "https://docs.rs/ntex-openssl/" +categories = ["network-programming", "asynchronous"] +license = "MIT" +edition = "2018" + +[lib] +name = "ntex_openssl" +path = "src/lib.rs" + +[dependencies] +ntex-bytes = "0.1.7" +ntex-io = "0.1.0" +ntex-util = "0.1.2" +openssl = "0.10.32" + +[dev-dependencies] +ntex = { version = "0.4.14", features = ["openssl"] } +futures = "0.3" +env_logger = "0.9" diff --git a/ntex-openssl/LICENSE b/ntex-openssl/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/ntex-openssl/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/ntex-openssl/examples/cert.pem b/ntex-openssl/examples/cert.pem new file mode 100644 index 00000000..9a744d16 --- /dev/null +++ b/ntex-openssl/examples/cert.pem @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE----- +MIICljCCAX4CCQDztMNlxk6oeTANBgkqhkiG9w0BAQsFADANMQswCQYDVQQIDAJj +YTAeFw0xOTAzMDcwNzEyNThaFw0yMDAzMDYwNzEyNThaMA0xCzAJBgNVBAgMAmNh +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0GMP3YzDVFWgNhRiHnfe +d192131Zi23p8WiutneD9I5WO42c79fOXsxLWn+2HSqPvCPHIBLoMX8o9lgCxt2P +/JUCAWbrE2EuvhkMrWk6/q7xB211XZYfnkqdt7mA0jMUC5o32AX3ew456TAq5P8Y +dq9H/qXdRtAvKD0QdkFfq8ePCiqOhcqacZ/NWva7R4HdgTnbL1DRQjGBXszI07P9 +1yw8GOym46uxNHRujQp3lYEhc1V3JTF9kETpSBHyEAkQ8WHxGf8UBHDhh7hcc+KI +JHMlVYy5wDv4ZJeYsY1rD6/n4tyd3r0yzBM57UGf6qrVZEYmLB7Jad+8Df5vIoGh +WwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB1DEu9NiShCfQuA17MG5O0Jr2/PS1z +/+HW7oW15WXpqDKOEJalid31/Bzwvwq0bE12xKE4ZLdbqJHmJTdSUoGfOfBZKka6 +R2thOjqH7hFvxjfgS7kBy5BrRZewM9xKIJ6zU6+6mxR64x9vmkOmppV0fx5clZjH +c7qn5kSNWTMsFbjPnb5BeJJwZdqpMLs99jgoMvGtCUmkyVYODGhh65g6tR9kIPvM +zu/Cw122/y7tFfkuknMSYwGEYF3XcZpXt54a6Lu5hk6PuOTsK+7lC+HX7CSF1dpv +u1szL5fDgiCBFCnyKeOqF61mxTCUht3U++37VDFvhzN1t6HIVTYm2JJ7 +-----END CERTIFICATE----- diff --git a/ntex-openssl/examples/client.rs b/ntex-openssl/examples/client.rs new file mode 100644 index 00000000..e9f6bfdd --- /dev/null +++ b/ntex-openssl/examples/client.rs @@ -0,0 +1,35 @@ +use std::io; + +use ntex::{codec, connect, util::Bytes, util::Either}; +use openssl::ssl::{self, SslMethod, SslVerifyMode}; + +#[ntex::main] +async fn main() -> io::Result<()> { + std::env::set_var("RUST_LOG", "trace"); + env_logger::init(); + + println!("Connecting to openssl server: 127.0.0.1:8443"); + + // load ssl keys + let mut builder = ssl::SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let connector = builder.build(); + + // start server + let connector = connect::openssl::IoConnector::new(connector); + + let io = connector.connect("127.0.0.1:8443").await.unwrap(); + println!("Connected to ssl server"); + let result = io + .send(Bytes::from_static(b"hello"), &codec::BytesCodec) + .await + .map_err(Either::into_inner)?; + + let resp = io + .next(&codec::BytesCodec) + .await + .map_err(Either::into_inner)?; + + println!("disconnecting"); + io.shutdown().await +} diff --git a/ntex-openssl/examples/key.pem b/ntex-openssl/examples/key.pem new file mode 100644 index 00000000..4416facc --- /dev/null +++ b/ntex-openssl/examples/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDQYw/djMNUVaA2 +FGIed953X3bXfVmLbenxaK62d4P0jlY7jZzv185ezEtaf7YdKo+8I8cgEugxfyj2 +WALG3Y/8lQIBZusTYS6+GQytaTr+rvEHbXVdlh+eSp23uYDSMxQLmjfYBfd7Djnp +MCrk/xh2r0f+pd1G0C8oPRB2QV+rx48KKo6Fyppxn81a9rtHgd2BOdsvUNFCMYFe +zMjTs/3XLDwY7Kbjq7E0dG6NCneVgSFzVXclMX2QROlIEfIQCRDxYfEZ/xQEcOGH +uFxz4ogkcyVVjLnAO/hkl5ixjWsPr+fi3J3evTLMEzntQZ/qqtVkRiYsHslp37wN +/m8igaFbAgMBAAECggEAJI278rkGany6pcHdlEqik34DcrliQ7r8FoSuYQOF+hgd +uESXCttoL+jWLwHICEW3AOGlxFKMuGH95Xh6xDeJUl0xBN3wzm11rZLnTmPvHU3C +qfLha5Ex6qpcECZSGo0rLv3WXeZuCv/r2KPCYnj86ZTFpD2kGw/Ztc1AXf4Jsi/1 +478Mf23QmAvCAPimGCyjLQx2c9/vg/6K7WnDevY4tDuDKLeSJxKZBSHUn3cM1Bwj +2QzaHfSFA5XljOF5PLeR3cY5ncrrVLWChT9XuGt9YMdLAcSQxgE6kWV1RSCq+lbj +e6OOe879IrrqwBvMQfKQqnm1kl8OrfPMT5CNWKvEgQKBgQD8q5E4x9taDS9RmhRO +07ptsr/I795tX8CaJd/jc4xGuCGBqpNw/hVebyNNYQvpiYzDNBSEhtd59957VyET +hcrGyxD0ByKm8F/lPgFw5y6wi3RUnucCV/jxkMHmxVzYMbFUEGCQ0pIU9/GFS7RZ +9VjqRDeE86U3yHO+WCFoHtd8aQKBgQDTIhi0uq0oY87bUGnWbrrkR0UVRNPDG1BT +cuXACYlv/DV/XpxPC8iPK1UwG4XaOVxodtIRjdBqvb8fUM6HSY6qll64N/4/1jre +Ho+d4clE4tK6a9WU96CKxwHn2BrWUZJPtoldaCZJFJ7SfiHuLlqW7TtYFrOfPIjN +ADiqK+bHIwKBgQCpfIiAVwebo0Z/bWR77+iZFxMwvT4tjdJLVGaXUvXgpjjLmtkm +LTm2S8SZbiSodfz3H+M3dp/pj8wsXiiwyMlZifOITZT/+DPLOUmMK3cVM6ZH8QMy +fkJd/+UhYHhECSlTI10zKByXdi4LZNnIkhwfoLzBMRI9lfeV0dYu2qlfKQKBgEVI +kRbtk1kHt5/ceX62g3nZsV/TYDJMSkW4FJC6EHHBL8UGRQDjewMQUzogLgJ4hEx7 +gV/lS5lbftZF7CAVEU4FXjvRlAtav6KYIMTMjQGf9UrbjBEAWZxwxb1Q+y2NQxgJ +bHZMcRPWQnAMmBHTAEM6whicCoGcmb+77Nxa37ZFAoGBALBuUNeD3fKvQR8v6GoA +spv+RYL9TB4wz2Oe9EYSp9z5EiWlTmuvFz3zk8pHDSpntxYH5O5HJ/3OzwhHz9ym ++DNE9AP9LW9hAzMuu7Gob1h8ShGwJVYwrQN3q/83ooUL7WSAuVOLpzJ7BFFlcCjp +MhFvd9iOt/R0N30/3AbQXkOp +-----END PRIVATE KEY----- diff --git a/ntex-openssl/examples/server.rs b/ntex-openssl/examples/server.rs new file mode 100644 index 00000000..fde67e62 --- /dev/null +++ b/ntex-openssl/examples/server.rs @@ -0,0 +1,54 @@ +use std::io; + +use ntex::service::{fn_service, pipeline_factory}; +use ntex::{codec, io::filter_factory, io::into_io, io::Io, server, util::Either}; +use ntex_openssl::SslAcceptor; +use openssl::ssl::{self, SslFiletype, SslMethod}; + +#[ntex::main] +async fn main() -> io::Result<()> { + std::env::set_var("RUST_LOG", "trace"); + env_logger::init(); + + println!("Started openssl echp server: 127.0.0.1:8443"); + + // load ssl keys + let mut builder = ssl::SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); + builder + .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .unwrap(); + builder + .set_certificate_chain_file("../tests/cert.pem") + .unwrap(); + let acceptor = builder.build(); + + // start server + server::ServerBuilder::new() + .bind("basic", "127.0.0.1:8443", move || { + pipeline_factory(into_io()) + .and_then(filter_factory(SslAcceptor::new(acceptor.clone()))) + .and_then(fn_service(|io: Io<_>| async move { + println!("New client is connected"); + loop { + match io.next(&codec::BytesCodec).await { + Ok(Some(msg)) => { + println!("Got message: {:?}", msg); + io.send(msg.freeze(), &codec::BytesCodec) + .await + .map_err(Either::into_inner)?; + } + Ok(None) => break, + Err(e) => { + println!("Got error: {:?}", e); + break; + } + } + } + println!("Client is disconnected"); + Ok(()) + })) + })? + .workers(1) + .run() + .await +} diff --git a/ntex-openssl/src/lib.rs b/ntex-openssl/src/lib.rs new file mode 100644 index 00000000..0becd770 --- /dev/null +++ b/ntex-openssl/src/lib.rs @@ -0,0 +1,320 @@ +#![allow(clippy::type_complexity)] +//! An implementation of SSL streams for ntex backed by OpenSSL +use std::cell::RefCell; +use std::{cmp, error::Error, future::Future, io, pin::Pin, task::Context, task::Poll}; + +use ntex_bytes::{BufMut, BytesMut}; +use ntex_io::{ + Filter, FilterFactory, Io, IoRef, ReadFilter, WriteFilter, WriteReadiness, +}; +use ntex_util::{future::poll_fn, time, time::Millis}; +use openssl::ssl::{self, SslStream}; + +pub struct SslFilter { + inner: RefCell>>, +} + +struct IoInner { + inner: F, + read_buf: Option, + write_buf: Option, +} + +impl io::Read for IoInner { + fn read(&mut self, dst: &mut [u8]) -> io::Result { + if let Some(ref mut buf) = self.read_buf { + if buf.is_empty() { + buf.clear(); + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } else { + let len = cmp::min(buf.len(), dst.len()); + dst.copy_from_slice(&buf.split_to(len)); + Ok(len) + } + } else { + Err(io::Error::from(io::ErrorKind::WouldBlock)) + } + } +} + +impl io::Write for IoInner { + fn write(&mut self, src: &[u8]) -> io::Result { + let mut buf = if let Some(mut buf) = self.inner.get_write_buf() { + buf.reserve(buf.len()); + buf + } else { + BytesMut::with_capacity(src.len()) + }; + buf.extend_from_slice(src); + self.inner.release_write_buf(buf)?; + Ok(src.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl Filter for SslFilter { + fn shutdown(&self, st: &IoRef) -> Poll> { + let ssl_result = self.inner.borrow_mut().shutdown(); + match ssl_result { + Ok(ssl::ShutdownResult::Sent) => Poll::Pending, + Ok(ssl::ShutdownResult::Received) => { + self.inner.borrow().get_ref().inner.shutdown(st) + } + Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Poll::Ready(Ok(())), + Err(ref e) + if e.code() == ssl::ErrorCode::WANT_READ + || e.code() == ssl::ErrorCode::WANT_WRITE => + { + Poll::Pending + } + Err(e) => Poll::Ready(Err(e + .into_io_error() + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))), + } + } +} + +impl ReadFilter for SslFilter { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.inner.borrow().get_ref().inner.poll_read_ready(cx) + } + + fn read_closed(&self, err: Option) { + self.inner.borrow().get_ref().inner.read_closed(err) + } + + fn get_read_buf(&self) -> Option { + if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() { + if !buf.is_empty() { + return Some(buf); + } + } + None + } + + fn release_read_buf( + &self, + src: BytesMut, + new_bytes: usize, + ) -> Result<(), io::Error> { + // store to read_buf + self.inner.borrow_mut().get_mut().read_buf = Some(src); + if new_bytes == 0 { + return Ok(()); + } + + let mut buf = + if let Some(buf) = self.inner.borrow().get_ref().inner.get_read_buf() { + buf + } else { + BytesMut::with_capacity(4096) + }; + + let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *buf.chunk_mut()) }; + let ssl_result = self.inner.borrow_mut().ssl_read(chunk); + let result = match ssl_result { + Ok(v) => { + unsafe { buf.advance_mut(v) }; + self.inner + .borrow() + .get_ref() + .inner + .release_read_buf(buf, v)?; + Ok(()) + } + Err(e) => match e.code() { + ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()), + _ => (Err(map_to_ioerr(e))), + }, + }; + result + } +} + +impl WriteFilter for SslFilter { + fn poll_write_ready( + &self, + cx: &mut Context<'_>, + ) -> Poll> { + self.inner.borrow().get_ref().inner.poll_write_ready(cx) + } + + fn write_closed(&self, err: Option) { + self.inner.borrow().get_ref().inner.read_closed(err) + } + + fn get_write_buf(&self) -> Option { + if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() { + if !buf.is_empty() { + return Some(buf); + } + } + None + } + + fn release_write_buf(&self, mut buf: BytesMut) -> Result<(), io::Error> { + let ssl_result = self.inner.borrow_mut().ssl_write(&buf); + let result = match ssl_result { + Ok(v) => { + if v != buf.len() { + buf.split_to(v); + self.inner.borrow_mut().get_mut().write_buf = Some(buf); + } + Ok(()) + } + Err(e) => match e.code() { + ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()), + _ => (Err(map_to_ioerr(e))), + }, + }; + result + } +} + +pub struct SslAcceptor { + acceptor: ssl::SslAcceptor, + timeout: Millis, +} + +impl SslAcceptor { + /// Create openssl acceptor filter factory + pub fn new(acceptor: ssl::SslAcceptor) -> Self { + SslAcceptor { + acceptor, + timeout: Millis(5_000), + } + } + + /// Set handshake timeout. + /// + /// Default is set to 5 seconds. + pub fn timeout>(mut self, timeout: U) -> Self { + self.timeout = timeout.into(); + self + } +} + +impl Clone for SslAcceptor { + fn clone(&self) -> Self { + Self { + acceptor: self.acceptor.clone(), + timeout: self.timeout, + } + } +} + +impl FilterFactory for SslAcceptor { + type Filter = SslFilter; + + type Error = io::Error; + type Future = Pin, Self::Error>>>>; + + fn create(self, st: Io) -> Self::Future { + let timeout = self.timeout; + let ctx_result = ssl::Ssl::new(self.acceptor.context()); + + Box::pin(async move { + time::timeout(timeout, async { + let ssl = ctx_result.map_err(map_to_ioerr)?; + let st = st.map_filter::(|inner: F| { + let inner = IoInner { + inner, + read_buf: None, + write_buf: None, + }; + let ssl_stream = + ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?; + + Ok(SslFilter { + inner: RefCell::new(ssl_stream), + }) + })?; + + poll_fn(|cx| { + let _ = st.write().poll_flush(cx)?; + handle_result(st.filter().inner.borrow_mut().accept(), &st, cx) + .map_err(map_to_ioerr) + }) + .await?; + + Ok(st) + }) + .await + .map_err(|_| { + io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout") + }) + .and_then(|item| item) + }) + } +} + +pub struct SslConnector { + ssl: ssl::Ssl, +} + +impl SslConnector { + /// Create openssl connector filter factory + pub fn new(ssl: ssl::Ssl) -> Self { + SslConnector { ssl } + } +} + +impl FilterFactory for SslConnector { + type Filter = SslFilter; + + type Error = io::Error; + type Future = Pin, Self::Error>>>>; + + fn create(self, st: Io) -> Self::Future { + Box::pin(async move { + let ssl = self.ssl; + let st = st.map_filter::(|inner: F| { + let inner = IoInner { + inner, + read_buf: None, + write_buf: None, + }; + let ssl_stream = + ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?; + + Ok(SslFilter { + inner: RefCell::new(ssl_stream), + }) + })?; + + poll_fn(|cx| { + let _ = st.write().poll_flush(cx)?; + handle_result(st.filter().inner.borrow_mut().connect(), &st, cx) + .map_err(map_to_ioerr) + }) + .await?; + + Ok(st) + }) + } +} + +fn handle_result( + result: Result, + st: &IoRef, + cx: &mut Context<'_>, +) -> Poll> { + match result { + Ok(v) => Poll::Ready(Ok(v)), + Err(e) => match e.code() { + ssl::ErrorCode::WANT_READ => { + let _ = st.read().poll_ready(cx); + Poll::Pending + } + ssl::ErrorCode::WANT_WRITE => Poll::Pending, + _ => Poll::Ready(Err(e)), + }, + } +} + +fn map_to_ioerr>>(err: E) -> io::Error { + io::Error::new(io::ErrorKind::Other, err) +} diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 44915d73..8cb36f2c 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -24,7 +24,7 @@ path = "src/lib.rs" default = ["http-framework"] # openssl -openssl = ["open-ssl", "tokio-openssl"] +openssl = ["open-ssl", "tokio-openssl", "ntex-openssl"] # rustls support rustls = ["rust-tls", "rustls-pemfile", "tokio-rustls", "webpki", "webpki-roots"] @@ -51,6 +51,7 @@ ntex-macros = "0.1.3" ntex-util = "0.1.2" ntex-bytes = "0.1.7" ntex-io = { version = "0.1", features = ["tokio"] } +ntex-openssl = { version = "0.1", optional = true } base64 = "0.13" bitflags = "1.3" diff --git a/ntex/src/connect/openssl.rs b/ntex/src/connect/openssl.rs index c46b09fe..7e7c1c39 100644 --- a/ntex/src/connect/openssl.rs +++ b/ntex/src/connect/openssl.rs @@ -1,13 +1,15 @@ use std::{future::Future, io, pin::Pin, task::Context, task::Poll}; +use ntex_openssl::{SslConnector as IoSslConnector, SslFilter}; pub use open_ssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; pub use tokio_openssl::SslStream; +use crate::io::{DefaultFilter, Io}; use crate::rt::net::TcpStream; use crate::service::{Service, ServiceFactory}; use crate::util::Ready; -use super::{Address, Connect, ConnectError, Connector}; +use super::{Address, Connect, ConnectError, Connector, IoConnector as BaseIoConnector}; pub struct OpensslConnector { connector: Connector, @@ -106,6 +108,101 @@ impl Service for OpensslConnector { } } +pub struct IoConnector { + connector: BaseIoConnector, + openssl: SslConnector, +} + +impl IoConnector { + /// Construct new OpensslConnectService factory + pub fn new(connector: SslConnector) -> Self { + IoConnector { + connector: BaseIoConnector::default(), + openssl: connector, + } + } +} + +impl IoConnector { + /// Resolve and connect to remote host + pub fn connect( + &self, + message: U, + ) -> impl Future>, ConnectError>> + where + Connect: From, + { + let message = Connect::from(message); + let host = message.host().to_string(); + let conn = self.connector.call(message); + let openssl = self.openssl.clone(); + + async move { + let io = conn.await?; + trace!("SSL Handshake start for: {:?}", host); + + match openssl.configure() { + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()), + Ok(config) => { + let ssl = config + .into_ssl(&host) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + match io.add_filter(IoSslConnector::new(ssl)).await { + Ok(io) => { + trace!("SSL Handshake success: {:?}", host); + Ok(io) + } + Err(e) => { + trace!("SSL Handshake error: {:?}", e); + Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)) + .into()) + } + } + } + } + } + } +} + +impl Clone for IoConnector { + fn clone(&self) -> Self { + IoConnector { + connector: self.connector.clone(), + openssl: self.openssl.clone(), + } + } +} + +impl ServiceFactory for IoConnector { + type Request = Connect; + type Response = Io>; + type Error = ConnectError; + type Config = (); + type Service = IoConnector; + type InitError = (); + type Future = Ready; + + fn new_service(&self, _: ()) -> Self::Future { + Ready::Ok(self.clone()) + } +} + +impl Service for IoConnector { + type Request = Connect; + type Response = Io>; + type Error = ConnectError; + type Future = Pin>>>; + + #[inline] + fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&self, req: Connect) -> Self::Future { + Box::pin(self.connect(req)) + } +} + #[cfg(test)] mod tests { use super::*;