mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-03 21:07:39 +03:00
Refactor tls impl (#237)
This commit is contained in:
parent
d460d9c259
commit
24ff5d6909
9 changed files with 114 additions and 97 deletions
|
@ -1,8 +1,8 @@
|
|||
# Changes
|
||||
|
||||
## [0.3.4] - 2023-11-xx
|
||||
|
||||
## [0.3.4] - 2023-11-03
|
||||
|
||||
* Add Io::force_ready_ready() and Io::poll_force_ready_ready() methods
|
||||
|
||||
## [0.3.3] - 2023-09-11
|
||||
|
||||
|
|
|
@ -31,6 +31,8 @@ bitflags::bitflags! {
|
|||
const RD_READY = 0b0000_0000_0010_0000;
|
||||
/// read buffer is full
|
||||
const RD_BUF_FULL = 0b0000_0000_0100_0000;
|
||||
/// any new data is available
|
||||
const RD_FORCE_READY = 0b0000_0000_1000_0000;
|
||||
|
||||
/// wait write completion
|
||||
const WR_WAIT = 0b0000_0001_0000_0000;
|
||||
|
@ -78,10 +80,15 @@ impl IoState {
|
|||
self.flags.set(flags);
|
||||
}
|
||||
|
||||
pub(super) fn remove_flags(&self, f: Flags) {
|
||||
pub(super) fn remove_flags(&self, f: Flags) -> bool {
|
||||
let mut flags = self.flags.get();
|
||||
flags.remove(f);
|
||||
self.flags.set(flags);
|
||||
if flags.intersects(f) {
|
||||
flags.remove(f);
|
||||
self.flags.set(flags);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn notify_keepalive(&self) {
|
||||
|
@ -365,6 +372,13 @@ impl<F> Io<F> {
|
|||
poll_fn(|cx| self.poll_read_ready(cx)).await
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[inline]
|
||||
/// Wait until read becomes ready.
|
||||
pub async fn force_read_ready(&self) -> io::Result<Option<()>> {
|
||||
poll_fn(|cx| self.poll_force_read_ready(cx)).await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Pause read task
|
||||
pub fn pause(&self) {
|
||||
|
@ -455,6 +469,39 @@ impl<F> Io<F> {
|
|||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[inline]
|
||||
/// Polls for read readiness.
|
||||
///
|
||||
/// If the io stream is not currently ready for reading,
|
||||
/// this method will store a clone of the Waker from the provided Context.
|
||||
/// When the io stream becomes ready for reading, Waker::wake will be called on the waker.
|
||||
///
|
||||
/// Return value
|
||||
/// The function returns:
|
||||
///
|
||||
/// `Poll::Pending` if the io stream is not ready for reading.
|
||||
/// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading.
|
||||
/// `Poll::Ready(Ok(None))` if io stream is disconnected
|
||||
/// `Some(Poll::Ready(Err(e)))` if an error is encountered.
|
||||
pub fn poll_force_read_ready(
|
||||
&self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<Option<()>>> {
|
||||
let ready = self.poll_read_ready(cx);
|
||||
|
||||
if ready.is_pending() {
|
||||
if self.0 .0.remove_flags(Flags::RD_FORCE_READY) {
|
||||
Poll::Ready(Ok(Some(())))
|
||||
} else {
|
||||
self.0 .0.insert_flags(Flags::RD_FORCE_READY);
|
||||
Poll::Pending
|
||||
}
|
||||
} else {
|
||||
ready
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Decode codec item from incoming bytes stream.
|
||||
///
|
||||
|
|
|
@ -61,6 +61,10 @@ impl ReadContext {
|
|||
// so we need to wake up read task to read more data
|
||||
// otherwise read task would sleep forever
|
||||
inner.read_task.wake();
|
||||
} else if inner.flags.get().contains(Flags::RD_FORCE_READY) {
|
||||
// in case of "force read" we must wake up dispatch task
|
||||
// if we read any data from source
|
||||
inner.dispatch_task.wake();
|
||||
}
|
||||
|
||||
// while reading, filter wrote some data
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
# Changes
|
||||
|
||||
## [0.3.2] - 2023-11-03
|
||||
|
||||
* Improve implementation
|
||||
|
||||
## [0.3.1] - 2023-09-11
|
||||
|
||||
* Add missing fmt::Debug impls
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex-tls"
|
||||
version = "0.3.1"
|
||||
version = "0.3.2"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
||||
keywords = ["network", "framework", "async", "futures"]
|
||||
|
@ -26,9 +26,9 @@ rustls = ["tls_rust"]
|
|||
|
||||
[dependencies]
|
||||
ntex-bytes = "0.1.19"
|
||||
ntex-io = "0.3.3"
|
||||
ntex-util = "0.3.2"
|
||||
ntex-service = "1.2.6"
|
||||
ntex-io = "0.3.4"
|
||||
ntex-util = "0.3.3"
|
||||
ntex-service = "1.2.7"
|
||||
log = "0.4"
|
||||
pin-project-lite = "0.2"
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::{any, cmp, error::Error, fmt, io, task::Context, task::Poll};
|
||||
use std::cell::RefCell;
|
||||
use std::{any, cmp, error::Error, fmt, io, task::Poll};
|
||||
|
||||
use ntex_bytes::{BufMut, BytesVec};
|
||||
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
|
||||
use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis};
|
||||
use ntex_util::{future::BoxFuture, time, time::Millis};
|
||||
use tls_openssl::ssl::{self, NameType, SslStream};
|
||||
use tls_openssl::x509::X509;
|
||||
|
||||
|
@ -25,7 +25,6 @@ pub struct PeerCertChain(pub Vec<X509>);
|
|||
#[derive(Debug)]
|
||||
pub struct SslFilter {
|
||||
inner: RefCell<SslStream<IoInner>>,
|
||||
handshake: Cell<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -147,7 +146,7 @@ impl FilterLayer for SslFilter {
|
|||
buf.with_write_buf(|b| {
|
||||
self.with_buffers(b, || {
|
||||
buf.with_dst(|dst| {
|
||||
let mut new_bytes = usize::from(self.handshake.get());
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
buf.resize_buf(dst);
|
||||
|
||||
|
@ -270,27 +269,21 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
|
|||
destination: None,
|
||||
};
|
||||
let filter = SslFilter {
|
||||
handshake: Cell::new(true),
|
||||
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
poll_fn(|cx| {
|
||||
let result = io
|
||||
.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
|
||||
})
|
||||
.map_err(|err| {
|
||||
let err: Box<dyn Error> =
|
||||
io::Error::new(io::ErrorKind::Other, err).into();
|
||||
err
|
||||
})?;
|
||||
handle_result(result, &io, cx)
|
||||
})
|
||||
.await?;
|
||||
log::debug!("Accepting tls connection");
|
||||
loop {
|
||||
let result = io.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
|
||||
})?;
|
||||
if handle_result(&io, result).await?.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
io.filter().handshake.set(false);
|
||||
Ok(io)
|
||||
})
|
||||
.await
|
||||
|
@ -327,55 +320,41 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
|
|||
destination: None,
|
||||
};
|
||||
let filter = SslFilter {
|
||||
handshake: Cell::new(true),
|
||||
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
poll_fn(|cx| {
|
||||
let result = io
|
||||
.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
|
||||
})
|
||||
.map_err(|err| {
|
||||
let err: Box<dyn Error> =
|
||||
io::Error::new(io::ErrorKind::Other, err).into();
|
||||
err
|
||||
})?;
|
||||
handle_result(result, &io, cx)
|
||||
})
|
||||
.await?;
|
||||
loop {
|
||||
let result = io.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
|
||||
})?;
|
||||
if handle_result(&io, result).await?.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
io.filter().handshake.set(false);
|
||||
Ok(io)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_result<T, F>(
|
||||
result: Result<T, ssl::Error>,
|
||||
async fn handle_result<T, F>(
|
||||
io: &Io<F>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<T, Box<dyn Error>>> {
|
||||
result: Result<T, ssl::Error>,
|
||||
) -> io::Result<Option<T>> {
|
||||
match result {
|
||||
Ok(v) => Poll::Ready(Ok(v)),
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(e) => match e.code() {
|
||||
ssl::ErrorCode::WANT_READ => {
|
||||
match ready!(io.poll_read_ready(cx)) {
|
||||
Ok(None) => Err::<_, Box<dyn Error>>(
|
||||
io::Error::new(io::ErrorKind::Other, "disconnected").into(),
|
||||
),
|
||||
Err(err) => Err(err.into()),
|
||||
_ => Ok(()),
|
||||
}?;
|
||||
Poll::Pending
|
||||
let res = io.force_read_ready().await;
|
||||
match res? {
|
||||
None => Err(io::Error::new(io::ErrorKind::Other, "disconnected")),
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
ssl::ErrorCode::WANT_WRITE => {
|
||||
let _ = io.poll_flush(cx, true)?;
|
||||
Poll::Pending
|
||||
}
|
||||
_ => Poll::Ready(Err(Box::new(e))),
|
||||
ssl::ErrorCode::WANT_WRITE => Ok(None),
|
||||
_ => Err(io::Error::new(io::ErrorKind::Other, e)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,20 +1,19 @@
|
|||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::io::{self, Read as IoRead, Write as IoWrite};
|
||||
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
|
||||
use std::{any, cell::RefCell, sync::Arc, task::Poll};
|
||||
|
||||
use ntex_bytes::BufMut;
|
||||
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
|
||||
use ntex_util::{future::poll_fn, ready};
|
||||
use tls_rust::{ClientConfig, ClientConnection, ServerName};
|
||||
|
||||
use crate::rustls::{IoInner, TlsFilter, Wrapper};
|
||||
use crate::rustls::{TlsFilter, Wrapper};
|
||||
|
||||
use super::{PeerCert, PeerCertChain};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// An implementation of SSL streams
|
||||
pub(crate) struct TlsClientFilter {
|
||||
inner: IoInner,
|
||||
session: RefCell<ClientConnection>,
|
||||
}
|
||||
|
||||
|
@ -59,7 +58,7 @@ impl FilterLayer for TlsClientFilter {
|
|||
|
||||
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
|
||||
let mut session = self.session.borrow_mut();
|
||||
let mut new_bytes = usize::from(self.inner.handshake.get());
|
||||
let mut new_bytes = 0;
|
||||
|
||||
// get processed buffer
|
||||
buf.with_src(|src| {
|
||||
|
@ -96,7 +95,7 @@ impl FilterLayer for TlsClientFilter {
|
|||
buf.with_src(|src| {
|
||||
if let Some(src) = src {
|
||||
let mut session = self.session.borrow_mut();
|
||||
let mut io = Wrapper(&self.inner, buf);
|
||||
let mut io = Wrapper(buf);
|
||||
|
||||
loop {
|
||||
if !src.is_empty() {
|
||||
|
@ -123,9 +122,6 @@ impl TlsClientFilter {
|
|||
let session = ClientConnection::new(cfg, domain)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
|
||||
let filter = TlsFilter::new_client(TlsClientFilter {
|
||||
inner: IoInner {
|
||||
handshake: Cell::new(true),
|
||||
},
|
||||
session: RefCell::new(session),
|
||||
});
|
||||
let io = io.add_filter(filter);
|
||||
|
@ -134,7 +130,7 @@ impl TlsClientFilter {
|
|||
loop {
|
||||
let (result, wants_read, handshaking) = io.with_buf(|buf| {
|
||||
let mut session = filter.client().session.borrow_mut();
|
||||
let mut wrp = Wrapper(&filter.client().inner, buf);
|
||||
let mut wrp = Wrapper(buf);
|
||||
let mut result = (
|
||||
session.complete_io(&mut wrp),
|
||||
session.wants_read(),
|
||||
|
@ -152,17 +148,15 @@ impl TlsClientFilter {
|
|||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
filter.client().inner.handshake.set(false);
|
||||
return Ok(io);
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
if !handshaking {
|
||||
filter.client().inner.handshake.set(false);
|
||||
return Ok(io);
|
||||
}
|
||||
poll_fn(|cx| {
|
||||
let read_ready = if wants_read {
|
||||
match ready!(io.poll_read_ready(cx))? {
|
||||
match ready!(io.poll_force_read_ready(cx))? {
|
||||
Some(_) => Ok(true),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#![allow(clippy::type_complexity)]
|
||||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll};
|
||||
use std::{any, cmp, io, sync::Arc, task::Context, task::Poll};
|
||||
|
||||
use ntex_io::{
|
||||
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
|
||||
|
@ -222,16 +222,11 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct IoInner {
|
||||
handshake: Cell<bool>,
|
||||
}
|
||||
|
||||
pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a WriteBuf<'b>);
|
||||
pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>);
|
||||
|
||||
impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
|
||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
||||
self.1.with_read_buf(|buf| {
|
||||
self.0.with_read_buf(|buf| {
|
||||
buf.with_src(|buf| {
|
||||
if let Some(buf) = buf {
|
||||
let len = cmp::min(buf.len(), dst.len());
|
||||
|
@ -248,7 +243,7 @@ impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
|
|||
|
||||
impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
|
||||
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
|
||||
self.1.with_dst(|buf| buf.extend_from_slice(src));
|
||||
self.0.with_dst(|buf| buf.extend_from_slice(src));
|
||||
Ok(src.len())
|
||||
}
|
||||
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::io::{self, Read as IoRead, Write as IoWrite};
|
||||
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
|
||||
use std::{any, cell::RefCell, sync::Arc, task::Poll};
|
||||
|
||||
use ntex_bytes::BufMut;
|
||||
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
|
||||
use ntex_util::{future::poll_fn, ready, time, time::Millis};
|
||||
use tls_rust::{ServerConfig, ServerConnection};
|
||||
|
||||
use crate::rustls::{IoInner, TlsFilter, Wrapper};
|
||||
use crate::rustls::{TlsFilter, Wrapper};
|
||||
use crate::Servername;
|
||||
|
||||
use super::{PeerCert, PeerCertChain};
|
||||
|
@ -15,7 +15,6 @@ use super::{PeerCert, PeerCertChain};
|
|||
#[derive(Debug)]
|
||||
/// An implementation of SSL streams
|
||||
pub(crate) struct TlsServerFilter {
|
||||
inner: IoInner,
|
||||
session: RefCell<ServerConnection>,
|
||||
}
|
||||
|
||||
|
@ -66,7 +65,7 @@ impl FilterLayer for TlsServerFilter {
|
|||
|
||||
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
|
||||
let mut session = self.session.borrow_mut();
|
||||
let mut new_bytes = usize::from(self.inner.handshake.get());
|
||||
let mut new_bytes = 0;
|
||||
|
||||
// get processed buffer
|
||||
buf.with_src(|src| {
|
||||
|
@ -103,7 +102,7 @@ impl FilterLayer for TlsServerFilter {
|
|||
buf.with_src(|src| {
|
||||
if let Some(src) = src {
|
||||
let mut session = self.session.borrow_mut();
|
||||
let mut io = Wrapper(&self.inner, buf);
|
||||
let mut io = Wrapper(buf);
|
||||
|
||||
loop {
|
||||
if !src.is_empty() {
|
||||
|
@ -132,9 +131,6 @@ impl TlsServerFilter {
|
|||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
|
||||
let filter = TlsFilter::new_server(TlsServerFilter {
|
||||
session: RefCell::new(session),
|
||||
inner: IoInner {
|
||||
handshake: Cell::new(true),
|
||||
},
|
||||
});
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
|
@ -142,7 +138,7 @@ impl TlsServerFilter {
|
|||
loop {
|
||||
let (result, wants_read, handshaking) = io.with_buf(|buf| {
|
||||
let mut session = filter.server().session.borrow_mut();
|
||||
let mut wrp = Wrapper(&filter.server().inner, buf);
|
||||
let mut wrp = Wrapper(buf);
|
||||
let mut result = (
|
||||
session.complete_io(&mut wrp),
|
||||
session.wants_read(),
|
||||
|
@ -160,17 +156,15 @@ impl TlsServerFilter {
|
|||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
filter.server().inner.handshake.set(false);
|
||||
return Ok(io);
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
if !handshaking {
|
||||
filter.server().inner.handshake.set(false);
|
||||
return Ok(io);
|
||||
}
|
||||
poll_fn(|cx| {
|
||||
let read_ready = if wants_read {
|
||||
match ready!(io.poll_read_ready(cx))? {
|
||||
match ready!(io.poll_force_read_ready(cx))? {
|
||||
Some(_) => Ok(true),
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue