Refactor tls impl (#237)

This commit is contained in:
Nikolay Kim 2023-11-03 17:33:45 +06:00 committed by GitHub
parent d460d9c259
commit 24ff5d6909
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 114 additions and 97 deletions

View file

@ -1,8 +1,8 @@
# Changes
## [0.3.4] - 2023-11-xx
## [0.3.4] - 2023-11-03
* Add Io::force_ready_ready() and Io::poll_force_ready_ready() methods
## [0.3.3] - 2023-09-11

View file

@ -31,6 +31,8 @@ bitflags::bitflags! {
const RD_READY = 0b0000_0000_0010_0000;
/// read buffer is full
const RD_BUF_FULL = 0b0000_0000_0100_0000;
/// any new data is available
const RD_FORCE_READY = 0b0000_0000_1000_0000;
/// wait write completion
const WR_WAIT = 0b0000_0001_0000_0000;
@ -78,10 +80,15 @@ impl IoState {
self.flags.set(flags);
}
pub(super) fn remove_flags(&self, f: Flags) {
pub(super) fn remove_flags(&self, f: Flags) -> bool {
let mut flags = self.flags.get();
flags.remove(f);
self.flags.set(flags);
if flags.intersects(f) {
flags.remove(f);
self.flags.set(flags);
true
} else {
false
}
}
pub(super) fn notify_keepalive(&self) {
@ -365,6 +372,13 @@ impl<F> Io<F> {
poll_fn(|cx| self.poll_read_ready(cx)).await
}
#[doc(hidden)]
#[inline]
/// Wait until read becomes ready.
pub async fn force_read_ready(&self) -> io::Result<Option<()>> {
poll_fn(|cx| self.poll_force_read_ready(cx)).await
}
#[inline]
/// Pause read task
pub fn pause(&self) {
@ -455,6 +469,39 @@ impl<F> Io<F> {
}
}
#[doc(hidden)]
#[inline]
/// Polls for read readiness.
///
/// If the io stream is not currently ready for reading,
/// this method will store a clone of the Waker from the provided Context.
/// When the io stream becomes ready for reading, Waker::wake will be called on the waker.
///
/// Return value
/// The function returns:
///
/// `Poll::Pending` if the io stream is not ready for reading.
/// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading.
/// `Poll::Ready(Ok(None))` if io stream is disconnected
/// `Some(Poll::Ready(Err(e)))` if an error is encountered.
pub fn poll_force_read_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<()>>> {
let ready = self.poll_read_ready(cx);
if ready.is_pending() {
if self.0 .0.remove_flags(Flags::RD_FORCE_READY) {
Poll::Ready(Ok(Some(())))
} else {
self.0 .0.insert_flags(Flags::RD_FORCE_READY);
Poll::Pending
}
} else {
ready
}
}
#[inline]
/// Decode codec item from incoming bytes stream.
///

View file

@ -61,6 +61,10 @@ impl ReadContext {
// so we need to wake up read task to read more data
// otherwise read task would sleep forever
inner.read_task.wake();
} else if inner.flags.get().contains(Flags::RD_FORCE_READY) {
// in case of "force read" we must wake up dispatch task
// if we read any data from source
inner.dispatch_task.wake();
}
// while reading, filter wrote some data

View file

@ -1,5 +1,9 @@
# Changes
## [0.3.2] - 2023-11-03
* Improve implementation
## [0.3.1] - 2023-09-11
* Add missing fmt::Debug impls

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.3.1"
version = "0.3.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,9 +26,9 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.19"
ntex-io = "0.3.3"
ntex-util = "0.3.2"
ntex-service = "1.2.6"
ntex-io = "0.3.4"
ntex-util = "0.3.3"
ntex-service = "1.2.7"
log = "0.4"
pin-project-lite = "0.2"

View file

@ -1,10 +1,10 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::{Cell, RefCell};
use std::{any, cmp, error::Error, fmt, io, task::Context, task::Poll};
use std::cell::RefCell;
use std::{any, cmp, error::Error, fmt, io, task::Poll};
use ntex_bytes::{BufMut, BytesVec};
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis};
use ntex_util::{future::BoxFuture, time, time::Millis};
use tls_openssl::ssl::{self, NameType, SslStream};
use tls_openssl::x509::X509;
@ -25,7 +25,6 @@ pub struct PeerCertChain(pub Vec<X509>);
#[derive(Debug)]
pub struct SslFilter {
inner: RefCell<SslStream<IoInner>>,
handshake: Cell<bool>,
}
#[derive(Debug)]
@ -147,7 +146,7 @@ impl FilterLayer for SslFilter {
buf.with_write_buf(|b| {
self.with_buffers(b, || {
buf.with_dst(|dst| {
let mut new_bytes = usize::from(self.handshake.get());
let mut new_bytes = 0;
loop {
buf.resize_buf(dst);
@ -270,27 +269,21 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
destination: None,
};
let filter = SslFilter {
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
poll_fn(|cx| {
let result = io
.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})
.map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
})
.await?;
log::debug!("Accepting tls connection");
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
}
io.filter().handshake.set(false);
Ok(io)
})
.await
@ -327,55 +320,41 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
destination: None,
};
let filter = SslFilter {
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
};
let io = io.add_filter(filter);
poll_fn(|cx| {
let result = io
.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})
.map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
})
.await?;
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
}
io.filter().handshake.set(false);
Ok(io)
})
}
}
fn handle_result<T, F>(
result: Result<T, ssl::Error>,
async fn handle_result<T, F>(
io: &Io<F>,
cx: &mut Context<'_>,
) -> Poll<Result<T, Box<dyn Error>>> {
result: Result<T, ssl::Error>,
) -> io::Result<Option<T>> {
match result {
Ok(v) => Poll::Ready(Ok(v)),
Ok(v) => Ok(Some(v)),
Err(e) => match e.code() {
ssl::ErrorCode::WANT_READ => {
match ready!(io.poll_read_ready(cx)) {
Ok(None) => Err::<_, Box<dyn Error>>(
io::Error::new(io::ErrorKind::Other, "disconnected").into(),
),
Err(err) => Err(err.into()),
_ => Ok(()),
}?;
Poll::Pending
let res = io.force_read_ready().await;
match res? {
None => Err(io::Error::new(io::ErrorKind::Other, "disconnected")),
_ => Ok(None),
}
}
ssl::ErrorCode::WANT_WRITE => {
let _ = io.poll_flush(cx, true)?;
Poll::Pending
}
_ => Poll::Ready(Err(Box::new(e))),
ssl::ErrorCode::WANT_WRITE => Ok(None),
_ => Err(io::Error::new(io::ErrorKind::Other, e)),
},
}
}

View file

@ -1,20 +1,19 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
use std::{any, cell::RefCell, sync::Arc, task::Poll};
use ntex_bytes::BufMut;
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, ready};
use tls_rust::{ClientConfig, ClientConnection, ServerName};
use crate::rustls::{IoInner, TlsFilter, Wrapper};
use crate::rustls::{TlsFilter, Wrapper};
use super::{PeerCert, PeerCertChain};
#[derive(Debug)]
/// An implementation of SSL streams
pub(crate) struct TlsClientFilter {
inner: IoInner,
session: RefCell<ClientConnection>,
}
@ -59,7 +58,7 @@ impl FilterLayer for TlsClientFilter {
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
let mut session = self.session.borrow_mut();
let mut new_bytes = usize::from(self.inner.handshake.get());
let mut new_bytes = 0;
// get processed buffer
buf.with_src(|src| {
@ -96,7 +95,7 @@ impl FilterLayer for TlsClientFilter {
buf.with_src(|src| {
if let Some(src) = src {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
let mut io = Wrapper(buf);
loop {
if !src.is_empty() {
@ -123,9 +122,6 @@ impl TlsClientFilter {
let session = ClientConnection::new(cfg, domain)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_client(TlsClientFilter {
inner: IoInner {
handshake: Cell::new(true),
},
session: RefCell::new(session),
});
let io = io.add_filter(filter);
@ -134,7 +130,7 @@ impl TlsClientFilter {
loop {
let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.client().session.borrow_mut();
let mut wrp = Wrapper(&filter.client().inner, buf);
let mut wrp = Wrapper(buf);
let mut result = (
session.complete_io(&mut wrp),
session.wants_read(),
@ -152,17 +148,15 @@ impl TlsClientFilter {
match result {
Ok(_) => {
filter.client().inner.handshake.set(false);
return Ok(io);
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if !handshaking {
filter.client().inner.handshake.set(false);
return Ok(io);
}
poll_fn(|cx| {
let read_ready = if wants_read {
match ready!(io.poll_read_ready(cx))? {
match ready!(io.poll_force_read_ready(cx))? {
Some(_) => Ok(true),
None => Err(io::Error::new(
io::ErrorKind::Other,

View file

@ -1,6 +1,6 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll};
use std::{any, cmp, io, sync::Arc, task::Context, task::Poll};
use ntex_io::{
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
@ -222,16 +222,11 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
}
}
#[derive(Debug)]
pub(crate) struct IoInner {
handshake: Cell<bool>,
}
pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a WriteBuf<'b>);
pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>);
impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
self.1.with_read_buf(|buf| {
self.0.with_read_buf(|buf| {
buf.with_src(|buf| {
if let Some(buf) = buf {
let len = cmp::min(buf.len(), dst.len());
@ -248,7 +243,7 @@ impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
self.1.with_dst(|buf| buf.extend_from_slice(src));
self.0.with_dst(|buf| buf.extend_from_slice(src));
Ok(src.len())
}

View file

@ -1,13 +1,13 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
use std::{any, cell::RefCell, sync::Arc, task::Poll};
use ntex_bytes::BufMut;
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection};
use crate::rustls::{IoInner, TlsFilter, Wrapper};
use crate::rustls::{TlsFilter, Wrapper};
use crate::Servername;
use super::{PeerCert, PeerCertChain};
@ -15,7 +15,6 @@ use super::{PeerCert, PeerCertChain};
#[derive(Debug)]
/// An implementation of SSL streams
pub(crate) struct TlsServerFilter {
inner: IoInner,
session: RefCell<ServerConnection>,
}
@ -66,7 +65,7 @@ impl FilterLayer for TlsServerFilter {
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
let mut session = self.session.borrow_mut();
let mut new_bytes = usize::from(self.inner.handshake.get());
let mut new_bytes = 0;
// get processed buffer
buf.with_src(|src| {
@ -103,7 +102,7 @@ impl FilterLayer for TlsServerFilter {
buf.with_src(|src| {
if let Some(src) = src {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
let mut io = Wrapper(buf);
loop {
if !src.is_empty() {
@ -132,9 +131,6 @@ impl TlsServerFilter {
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_server(TlsServerFilter {
session: RefCell::new(session),
inner: IoInner {
handshake: Cell::new(true),
},
});
let io = io.add_filter(filter);
@ -142,7 +138,7 @@ impl TlsServerFilter {
loop {
let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.server().session.borrow_mut();
let mut wrp = Wrapper(&filter.server().inner, buf);
let mut wrp = Wrapper(buf);
let mut result = (
session.complete_io(&mut wrp),
session.wants_read(),
@ -160,17 +156,15 @@ impl TlsServerFilter {
match result {
Ok(_) => {
filter.server().inner.handshake.set(false);
return Ok(io);
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if !handshaking {
filter.server().inner.handshake.set(false);
return Ok(io);
}
poll_fn(|cx| {
let read_ready = if wants_read {
match ready!(io.poll_read_ready(cx))? {
match ready!(io.poll_force_read_ready(cx))? {
Some(_) => Ok(true),
None => Err(io::Error::new(
io::ErrorKind::Other,