diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 9a9ae0b6..6b2d4c0d 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,8 +1,8 @@ # Changes -## [0.3.4] - 2023-11-xx - +## [0.3.4] - 2023-11-03 +* Add Io::force_ready_ready() and Io::poll_force_ready_ready() methods ## [0.3.3] - 2023-09-11 diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index 62faaefe..ec13e945 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -31,6 +31,8 @@ bitflags::bitflags! { const RD_READY = 0b0000_0000_0010_0000; /// read buffer is full const RD_BUF_FULL = 0b0000_0000_0100_0000; + /// any new data is available + const RD_FORCE_READY = 0b0000_0000_1000_0000; /// wait write completion const WR_WAIT = 0b0000_0001_0000_0000; @@ -78,10 +80,15 @@ impl IoState { self.flags.set(flags); } - pub(super) fn remove_flags(&self, f: Flags) { + pub(super) fn remove_flags(&self, f: Flags) -> bool { let mut flags = self.flags.get(); - flags.remove(f); - self.flags.set(flags); + if flags.intersects(f) { + flags.remove(f); + self.flags.set(flags); + true + } else { + false + } } pub(super) fn notify_keepalive(&self) { @@ -365,6 +372,13 @@ impl Io { poll_fn(|cx| self.poll_read_ready(cx)).await } + #[doc(hidden)] + #[inline] + /// Wait until read becomes ready. + pub async fn force_read_ready(&self) -> io::Result> { + poll_fn(|cx| self.poll_force_read_ready(cx)).await + } + #[inline] /// Pause read task pub fn pause(&self) { @@ -455,6 +469,39 @@ impl Io { } } + #[doc(hidden)] + #[inline] + /// Polls for read readiness. + /// + /// If the io stream is not currently ready for reading, + /// this method will store a clone of the Waker from the provided Context. + /// When the io stream becomes ready for reading, Waker::wake will be called on the waker. + /// + /// Return value + /// The function returns: + /// + /// `Poll::Pending` if the io stream is not ready for reading. + /// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading. + /// `Poll::Ready(Ok(None))` if io stream is disconnected + /// `Some(Poll::Ready(Err(e)))` if an error is encountered. + pub fn poll_force_read_ready( + &self, + cx: &mut Context<'_>, + ) -> Poll>> { + let ready = self.poll_read_ready(cx); + + if ready.is_pending() { + if self.0 .0.remove_flags(Flags::RD_FORCE_READY) { + Poll::Ready(Ok(Some(()))) + } else { + self.0 .0.insert_flags(Flags::RD_FORCE_READY); + Poll::Pending + } + } else { + ready + } + } + #[inline] /// Decode codec item from incoming bytes stream. /// diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 7417d27f..8c6edf17 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -61,6 +61,10 @@ impl ReadContext { // so we need to wake up read task to read more data // otherwise read task would sleep forever inner.read_task.wake(); + } else if inner.flags.get().contains(Flags::RD_FORCE_READY) { + // in case of "force read" we must wake up dispatch task + // if we read any data from source + inner.dispatch_task.wake(); } // while reading, filter wrote some data diff --git a/ntex-tls/CHANGES.md b/ntex-tls/CHANGES.md index 2f643bc8..3ab959af 100644 --- a/ntex-tls/CHANGES.md +++ b/ntex-tls/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.3.2] - 2023-11-03 + +* Improve implementation + ## [0.3.1] - 2023-09-11 * Add missing fmt::Debug impls diff --git a/ntex-tls/Cargo.toml b/ntex-tls/Cargo.toml index 4daab9e6..1793e596 100644 --- a/ntex-tls/Cargo.toml +++ b/ntex-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tls" -version = "0.3.1" +version = "0.3.2" authors = ["ntex contributors "] description = "An implementation of SSL streams for ntex backed by OpenSSL" keywords = ["network", "framework", "async", "futures"] @@ -26,9 +26,9 @@ rustls = ["tls_rust"] [dependencies] ntex-bytes = "0.1.19" -ntex-io = "0.3.3" -ntex-util = "0.3.2" -ntex-service = "1.2.6" +ntex-io = "0.3.4" +ntex-util = "0.3.3" +ntex-service = "1.2.7" log = "0.4" pin-project-lite = "0.2" diff --git a/ntex-tls/src/openssl/mod.rs b/ntex-tls/src/openssl/mod.rs index b9dbf1ab..81b20984 100644 --- a/ntex-tls/src/openssl/mod.rs +++ b/ntex-tls/src/openssl/mod.rs @@ -1,10 +1,10 @@ //! An implementation of SSL streams for ntex backed by OpenSSL -use std::cell::{Cell, RefCell}; -use std::{any, cmp, error::Error, fmt, io, task::Context, task::Poll}; +use std::cell::RefCell; +use std::{any, cmp, error::Error, fmt, io, task::Poll}; use ntex_bytes::{BufMut, BytesVec}; use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; -use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis}; +use ntex_util::{future::BoxFuture, time, time::Millis}; use tls_openssl::ssl::{self, NameType, SslStream}; use tls_openssl::x509::X509; @@ -25,7 +25,6 @@ pub struct PeerCertChain(pub Vec); #[derive(Debug)] pub struct SslFilter { inner: RefCell>, - handshake: Cell, } #[derive(Debug)] @@ -147,7 +146,7 @@ impl FilterLayer for SslFilter { buf.with_write_buf(|b| { self.with_buffers(b, || { buf.with_dst(|dst| { - let mut new_bytes = usize::from(self.handshake.get()); + let mut new_bytes = 0; loop { buf.resize_buf(dst); @@ -270,27 +269,21 @@ impl FilterFactory for SslAcceptor { destination: None, }; let filter = SslFilter { - handshake: Cell::new(true), inner: RefCell::new(ssl::SslStream::new(ssl, inner)?), }; let io = io.add_filter(filter); - poll_fn(|cx| { - let result = io - .with_buf(|buf| { - let filter = io.filter(); - filter.with_buffers(buf, || filter.inner.borrow_mut().accept()) - }) - .map_err(|err| { - let err: Box = - io::Error::new(io::ErrorKind::Other, err).into(); - err - })?; - handle_result(result, &io, cx) - }) - .await?; + log::debug!("Accepting tls connection"); + loop { + let result = io.with_buf(|buf| { + let filter = io.filter(); + filter.with_buffers(buf, || filter.inner.borrow_mut().accept()) + })?; + if handle_result(&io, result).await?.is_some() { + break; + } + } - io.filter().handshake.set(false); Ok(io) }) .await @@ -327,55 +320,41 @@ impl FilterFactory for SslConnector { destination: None, }; let filter = SslFilter { - handshake: Cell::new(true), inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?), }; let io = io.add_filter(filter); - poll_fn(|cx| { - let result = io - .with_buf(|buf| { - let filter = io.filter(); - filter.with_buffers(buf, || filter.inner.borrow_mut().connect()) - }) - .map_err(|err| { - let err: Box = - io::Error::new(io::ErrorKind::Other, err).into(); - err - })?; - handle_result(result, &io, cx) - }) - .await?; + loop { + let result = io.with_buf(|buf| { + let filter = io.filter(); + filter.with_buffers(buf, || filter.inner.borrow_mut().connect()) + })?; + if handle_result(&io, result).await?.is_some() { + break; + } + } - io.filter().handshake.set(false); Ok(io) }) } } -fn handle_result( - result: Result, +async fn handle_result( io: &Io, - cx: &mut Context<'_>, -) -> Poll>> { + result: Result, +) -> io::Result> { match result { - Ok(v) => Poll::Ready(Ok(v)), + Ok(v) => Ok(Some(v)), Err(e) => match e.code() { ssl::ErrorCode::WANT_READ => { - match ready!(io.poll_read_ready(cx)) { - Ok(None) => Err::<_, Box>( - io::Error::new(io::ErrorKind::Other, "disconnected").into(), - ), - Err(err) => Err(err.into()), - _ => Ok(()), - }?; - Poll::Pending + let res = io.force_read_ready().await; + match res? { + None => Err(io::Error::new(io::ErrorKind::Other, "disconnected")), + _ => Ok(None), + } } - ssl::ErrorCode::WANT_WRITE => { - let _ = io.poll_flush(cx, true)?; - Poll::Pending - } - _ => Poll::Ready(Err(Box::new(e))), + ssl::ErrorCode::WANT_WRITE => Ok(None), + _ => Err(io::Error::new(io::ErrorKind::Other, e)), }, } } diff --git a/ntex-tls/src/rustls/client.rs b/ntex-tls/src/rustls/client.rs index a00b3b42..6316c3ce 100644 --- a/ntex-tls/src/rustls/client.rs +++ b/ntex-tls/src/rustls/client.rs @@ -1,20 +1,19 @@ //! An implementation of SSL streams for ntex backed by OpenSSL use std::io::{self, Read as IoRead, Write as IoWrite}; -use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll}; +use std::{any, cell::RefCell, sync::Arc, task::Poll}; use ntex_bytes::BufMut; use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; use ntex_util::{future::poll_fn, ready}; use tls_rust::{ClientConfig, ClientConnection, ServerName}; -use crate::rustls::{IoInner, TlsFilter, Wrapper}; +use crate::rustls::{TlsFilter, Wrapper}; use super::{PeerCert, PeerCertChain}; #[derive(Debug)] /// An implementation of SSL streams pub(crate) struct TlsClientFilter { - inner: IoInner, session: RefCell, } @@ -59,7 +58,7 @@ impl FilterLayer for TlsClientFilter { fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result { let mut session = self.session.borrow_mut(); - let mut new_bytes = usize::from(self.inner.handshake.get()); + let mut new_bytes = 0; // get processed buffer buf.with_src(|src| { @@ -96,7 +95,7 @@ impl FilterLayer for TlsClientFilter { buf.with_src(|src| { if let Some(src) = src { let mut session = self.session.borrow_mut(); - let mut io = Wrapper(&self.inner, buf); + let mut io = Wrapper(buf); loop { if !src.is_empty() { @@ -123,9 +122,6 @@ impl TlsClientFilter { let session = ClientConnection::new(cfg, domain) .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; let filter = TlsFilter::new_client(TlsClientFilter { - inner: IoInner { - handshake: Cell::new(true), - }, session: RefCell::new(session), }); let io = io.add_filter(filter); @@ -134,7 +130,7 @@ impl TlsClientFilter { loop { let (result, wants_read, handshaking) = io.with_buf(|buf| { let mut session = filter.client().session.borrow_mut(); - let mut wrp = Wrapper(&filter.client().inner, buf); + let mut wrp = Wrapper(buf); let mut result = ( session.complete_io(&mut wrp), session.wants_read(), @@ -152,17 +148,15 @@ impl TlsClientFilter { match result { Ok(_) => { - filter.client().inner.handshake.set(false); return Ok(io); } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { if !handshaking { - filter.client().inner.handshake.set(false); return Ok(io); } poll_fn(|cx| { let read_ready = if wants_read { - match ready!(io.poll_read_ready(cx))? { + match ready!(io.poll_force_read_ready(cx))? { Some(_) => Ok(true), None => Err(io::Error::new( io::ErrorKind::Other, diff --git a/ntex-tls/src/rustls/mod.rs b/ntex-tls/src/rustls/mod.rs index 60880508..0e2c89b7 100644 --- a/ntex-tls/src/rustls/mod.rs +++ b/ntex-tls/src/rustls/mod.rs @@ -1,6 +1,6 @@ #![allow(clippy::type_complexity)] //! An implementation of SSL streams for ntex backed by OpenSSL -use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll}; +use std::{any, cmp, io, sync::Arc, task::Context, task::Poll}; use ntex_io::{ Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf, @@ -222,16 +222,11 @@ impl FilterFactory for TlsConnectorConfigured { } } -#[derive(Debug)] -pub(crate) struct IoInner { - handshake: Cell, -} - -pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a WriteBuf<'b>); +pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>); impl<'a, 'b> io::Read for Wrapper<'a, 'b> { fn read(&mut self, dst: &mut [u8]) -> io::Result { - self.1.with_read_buf(|buf| { + self.0.with_read_buf(|buf| { buf.with_src(|buf| { if let Some(buf) = buf { let len = cmp::min(buf.len(), dst.len()); @@ -248,7 +243,7 @@ impl<'a, 'b> io::Read for Wrapper<'a, 'b> { impl<'a, 'b> io::Write for Wrapper<'a, 'b> { fn write(&mut self, src: &[u8]) -> io::Result { - self.1.with_dst(|buf| buf.extend_from_slice(src)); + self.0.with_dst(|buf| buf.extend_from_slice(src)); Ok(src.len()) } diff --git a/ntex-tls/src/rustls/server.rs b/ntex-tls/src/rustls/server.rs index c01437ad..71891c71 100644 --- a/ntex-tls/src/rustls/server.rs +++ b/ntex-tls/src/rustls/server.rs @@ -1,13 +1,13 @@ //! An implementation of SSL streams for ntex backed by OpenSSL use std::io::{self, Read as IoRead, Write as IoWrite}; -use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll}; +use std::{any, cell::RefCell, sync::Arc, task::Poll}; use ntex_bytes::BufMut; use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; use ntex_util::{future::poll_fn, ready, time, time::Millis}; use tls_rust::{ServerConfig, ServerConnection}; -use crate::rustls::{IoInner, TlsFilter, Wrapper}; +use crate::rustls::{TlsFilter, Wrapper}; use crate::Servername; use super::{PeerCert, PeerCertChain}; @@ -15,7 +15,6 @@ use super::{PeerCert, PeerCertChain}; #[derive(Debug)] /// An implementation of SSL streams pub(crate) struct TlsServerFilter { - inner: IoInner, session: RefCell, } @@ -66,7 +65,7 @@ impl FilterLayer for TlsServerFilter { fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result { let mut session = self.session.borrow_mut(); - let mut new_bytes = usize::from(self.inner.handshake.get()); + let mut new_bytes = 0; // get processed buffer buf.with_src(|src| { @@ -103,7 +102,7 @@ impl FilterLayer for TlsServerFilter { buf.with_src(|src| { if let Some(src) = src { let mut session = self.session.borrow_mut(); - let mut io = Wrapper(&self.inner, buf); + let mut io = Wrapper(buf); loop { if !src.is_empty() { @@ -132,9 +131,6 @@ impl TlsServerFilter { .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; let filter = TlsFilter::new_server(TlsServerFilter { session: RefCell::new(session), - inner: IoInner { - handshake: Cell::new(true), - }, }); let io = io.add_filter(filter); @@ -142,7 +138,7 @@ impl TlsServerFilter { loop { let (result, wants_read, handshaking) = io.with_buf(|buf| { let mut session = filter.server().session.borrow_mut(); - let mut wrp = Wrapper(&filter.server().inner, buf); + let mut wrp = Wrapper(buf); let mut result = ( session.complete_io(&mut wrp), session.wants_read(), @@ -160,17 +156,15 @@ impl TlsServerFilter { match result { Ok(_) => { - filter.server().inner.handshake.set(false); return Ok(io); } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { if !handshaking { - filter.server().inner.handshake.set(false); return Ok(io); } poll_fn(|cx| { let read_ready = if wants_read { - match ready!(io.poll_read_ready(cx))? { + match ready!(io.poll_force_read_ready(cx))? { Some(_) => Ok(true), None => Err(io::Error::new( io::ErrorKind::Other,