mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-05 13:57:39 +03:00
Update Filter trait usage
This commit is contained in:
parent
1df005f53f
commit
b49a5ed195
6 changed files with 203 additions and 275 deletions
|
@ -1,5 +1,9 @@
|
||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
|
## [0.1.2] - 2022-01-12
|
||||||
|
|
||||||
|
* Update Filter trait usage
|
||||||
|
|
||||||
## [0.1.1] - 2022-01-10
|
## [0.1.1] - 2022-01-10
|
||||||
|
|
||||||
* Remove usage of ntex::io::Boxed types
|
* Remove usage of ntex::io::Boxed types
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "ntex-tls"
|
name = "ntex-tls"
|
||||||
version = "0.1.1"
|
version = "0.1.2"
|
||||||
authors = ["ntex contributors <team@ntex.rs>"]
|
authors = ["ntex contributors <team@ntex.rs>"]
|
||||||
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
||||||
keywords = ["network", "framework", "async", "futures"]
|
keywords = ["network", "framework", "async", "futures"]
|
||||||
|
@ -26,8 +26,8 @@ rustls = ["tls_rust"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ntex-bytes = "0.1.9"
|
ntex-bytes = "0.1.9"
|
||||||
ntex-io = "0.1.2"
|
ntex-io = "0.1.3"
|
||||||
ntex-util = "0.1.8"
|
ntex-util = "0.1.9"
|
||||||
ntex-service = "0.3.1"
|
ntex-service = "0.3.1"
|
||||||
pin-project-lite = "0.2"
|
pin-project-lite = "0.2"
|
||||||
|
|
||||||
|
|
|
@ -27,25 +27,26 @@ pub struct PeerCertChain(pub Vec<X509>);
|
||||||
/// An implementation of SSL streams
|
/// An implementation of SSL streams
|
||||||
pub struct SslFilter<F = Base> {
|
pub struct SslFilter<F = Base> {
|
||||||
inner: RefCell<SslStream<IoInner<F>>>,
|
inner: RefCell<SslStream<IoInner<F>>>,
|
||||||
|
pool: PoolRef,
|
||||||
handshake: Cell<bool>,
|
handshake: Cell<bool>,
|
||||||
|
read_buf: Cell<Option<BytesMut>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IoInner<F> {
|
struct IoInner<F> {
|
||||||
inner: F,
|
inner: F,
|
||||||
pool: PoolRef,
|
pool: PoolRef,
|
||||||
read_buf: Option<BytesMut>,
|
|
||||||
write_buf: Option<BytesMut>,
|
write_buf: Option<BytesMut>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Filter> io::Read for IoInner<F> {
|
impl<F: Filter> io::Read for IoInner<F> {
|
||||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
||||||
if let Some(ref mut buf) = self.read_buf {
|
if let Some(mut buf) = self.inner.get_read_buf() {
|
||||||
if buf.is_empty() {
|
if buf.is_empty() {
|
||||||
buf.clear();
|
|
||||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||||
} else {
|
} else {
|
||||||
let len = cmp::min(buf.len(), dst.len());
|
let len = cmp::min(buf.len(), dst.len());
|
||||||
dst[..len].copy_from_slice(&buf.split_to(len));
|
dst[..len].copy_from_slice(&buf.split_to(len));
|
||||||
|
self.inner.release_read_buf(buf);
|
||||||
Ok(len)
|
Ok(len)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -139,70 +140,54 @@ impl<F: Filter> Filter for SslFilter<F> {
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||||
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
|
self.read_buf.take()
|
||||||
if !buf.is_empty() {
|
|
||||||
return Some(buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn get_write_buf(&self) -> Option<BytesMut> {
|
fn get_write_buf(&self) -> Option<BytesMut> {
|
||||||
if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() {
|
self.inner.borrow_mut().get_mut().write_buf.take()
|
||||||
if !buf.is_empty() {
|
|
||||||
return Some(buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn release_read_buf(
|
#[inline]
|
||||||
&self,
|
fn release_read_buf(&self, buf: BytesMut) {
|
||||||
io: &IoRef,
|
self.read_buf.set(Some(buf));
|
||||||
src: BytesMut,
|
}
|
||||||
dst: &mut Option<BytesMut>,
|
|
||||||
nbytes: usize,
|
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
|
||||||
) -> io::Result<usize> {
|
// ask inner filter to process read buf
|
||||||
// store to read_buf
|
match self
|
||||||
let pool = {
|
.inner
|
||||||
let mut inner = self.inner.borrow_mut();
|
.borrow_mut()
|
||||||
let mut dst = None;
|
|
||||||
let result = inner
|
|
||||||
.get_ref()
|
.get_ref()
|
||||||
.inner
|
.inner
|
||||||
.release_read_buf(io, src, &mut dst, nbytes);
|
.process_read_buf(io, nbytes)
|
||||||
if let Err(err) = result {
|
{
|
||||||
io.want_shutdown(Some(err));
|
Err(err) => io.want_shutdown(Some(err)),
|
||||||
|
Ok((n, 0)) => return Ok((n, 0)),
|
||||||
|
Ok((_, _)) => (),
|
||||||
}
|
}
|
||||||
if dst.is_some() {
|
|
||||||
inner.get_mut().read_buf = dst;
|
|
||||||
inner.get_ref().pool
|
|
||||||
} else {
|
|
||||||
return Ok(0);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let (hw, lw) = pool.read_params().unpack();
|
|
||||||
|
|
||||||
// get inner filter buffer
|
// get processed buffer
|
||||||
if dst.is_none() {
|
let mut dst = if let Some(dst) = self.get_read_buf() {
|
||||||
*dst = Some(pool.get_read_buf());
|
dst
|
||||||
}
|
} else {
|
||||||
let buf = dst.as_mut().unwrap();
|
self.pool.get_read_buf()
|
||||||
|
};
|
||||||
|
let (hw, lw) = self.pool.read_params().unpack();
|
||||||
|
|
||||||
let mut new_bytes = 0;
|
let mut new_bytes = 0;
|
||||||
loop {
|
loop {
|
||||||
// make sure we've got room
|
// make sure we've got room
|
||||||
let remaining = buf.remaining_mut();
|
let remaining = dst.remaining_mut();
|
||||||
if remaining < lw {
|
if remaining < lw {
|
||||||
buf.reserve(hw - remaining);
|
dst.reserve(hw - remaining);
|
||||||
}
|
}
|
||||||
|
|
||||||
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
|
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
|
||||||
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
|
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
|
||||||
return match ssl_result {
|
let result = match ssl_result {
|
||||||
Ok(v) => {
|
Ok(v) => {
|
||||||
unsafe { buf.advance_mut(v) };
|
unsafe { dst.advance_mut(v) };
|
||||||
new_bytes += v;
|
new_bytes += v;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -216,14 +201,16 @@ impl<F: Filter> Filter for SslFilter<F> {
|
||||||
self.handshake.set(false);
|
self.handshake.set(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(new_bytes)
|
Ok((dst.len(), new_bytes))
|
||||||
}
|
}
|
||||||
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
|
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
|
||||||
io.want_shutdown(None);
|
io.want_shutdown(None);
|
||||||
Ok(new_bytes)
|
Ok((dst.len(), new_bytes))
|
||||||
}
|
}
|
||||||
Err(e) => Err(map_to_ioerr(e)),
|
Err(e) => Err(map_to_ioerr(e)),
|
||||||
};
|
};
|
||||||
|
self.release_read_buf(dst);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,18 +286,18 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
|
||||||
let ssl = ctx_result.map_err(map_to_ioerr)?;
|
let ssl = ctx_result.map_err(map_to_ioerr)?;
|
||||||
let pool = st.memory_pool();
|
let pool = st.memory_pool();
|
||||||
let st = st.map_filter(|inner: F| {
|
let st = st.map_filter(|inner: F| {
|
||||||
let read_buf = inner.get_read_buf();
|
|
||||||
let inner = IoInner {
|
let inner = IoInner {
|
||||||
pool,
|
pool,
|
||||||
inner,
|
inner,
|
||||||
read_buf,
|
|
||||||
write_buf: None,
|
write_buf: None,
|
||||||
};
|
};
|
||||||
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
||||||
|
|
||||||
Ok::<_, Box<dyn Error>>(SslFilter {
|
Ok::<_, Box<dyn Error>>(SslFilter {
|
||||||
inner: RefCell::new(ssl_stream),
|
pool,
|
||||||
|
read_buf: Cell::new(None),
|
||||||
handshake: Cell::new(true),
|
handshake: Cell::new(true),
|
||||||
|
inner: RefCell::new(ssl_stream),
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
@ -352,18 +339,18 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
|
||||||
let ssl = self.ssl;
|
let ssl = self.ssl;
|
||||||
let pool = st.memory_pool();
|
let pool = st.memory_pool();
|
||||||
let st = st.map_filter(|inner: F| {
|
let st = st.map_filter(|inner: F| {
|
||||||
let read_buf = inner.get_read_buf();
|
|
||||||
let inner = IoInner {
|
let inner = IoInner {
|
||||||
pool,
|
pool,
|
||||||
inner,
|
inner,
|
||||||
read_buf,
|
|
||||||
write_buf: None,
|
write_buf: None,
|
||||||
};
|
};
|
||||||
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
||||||
|
|
||||||
Ok::<_, Box<dyn Error>>(SslFilter {
|
Ok::<_, Box<dyn Error>>(SslFilter {
|
||||||
inner: RefCell::new(ssl_stream),
|
pool,
|
||||||
|
read_buf: Cell::new(None),
|
||||||
handshake: Cell::new(true),
|
handshake: Cell::new(true),
|
||||||
|
inner: RefCell::new(ssl_stream),
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||||
use std::io::{self, Read as IoRead, Write as IoWrite};
|
use std::io::{self, Read as IoRead, Write as IoWrite};
|
||||||
use std::{any, cell::RefCell, cmp, sync::Arc, task::Context, task::Poll};
|
use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
|
||||||
|
|
||||||
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
use ntex_bytes::{BufMut, BytesMut};
|
||||||
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
|
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
|
||||||
use ntex_util::{future::poll_fn, ready};
|
use ntex_util::{future::poll_fn, ready};
|
||||||
use tls_rust::{ClientConfig, ClientConnection, ServerName};
|
use tls_rust::{ClientConfig, ClientConnection, ServerName};
|
||||||
|
|
||||||
use super::TlsFilter;
|
use crate::rustls::{IoInner, TlsFilter, Wrapper};
|
||||||
use crate::types;
|
use crate::types;
|
||||||
|
|
||||||
/// An implementation of SSL streams
|
/// An implementation of SSL streams
|
||||||
|
@ -16,13 +16,6 @@ pub struct TlsClientFilter<F> {
|
||||||
session: RefCell<ClientConnection>,
|
session: RefCell<ClientConnection>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IoInner<F> {
|
|
||||||
inner: F,
|
|
||||||
pool: PoolRef,
|
|
||||||
read_buf: Option<BytesMut>,
|
|
||||||
write_buf: Option<BytesMut>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Filter> Filter for TlsClientFilter<F> {
|
impl<F: Filter> Filter for TlsClientFilter<F> {
|
||||||
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
|
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
|
||||||
const H2: &[u8] = b"h2";
|
const H2: &[u8] = b"h2";
|
||||||
|
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
||||||
};
|
};
|
||||||
Some(Box::new(proto))
|
Some(Box::new(proto))
|
||||||
} else {
|
} else {
|
||||||
self.inner.borrow().inner.query(id)
|
self.inner.borrow().filter.query(id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
|
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
|
||||||
self.inner.borrow().inner.poll_shutdown()
|
self.inner.borrow().filter.poll_shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
|
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
|
||||||
self.inner.borrow().inner.poll_read_ready(cx)
|
self.inner.borrow().filter.poll_read_ready(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
||||||
self.inner.borrow().inner.poll_write_ready(cx)
|
self.inner.borrow().filter.poll_write_ready(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||||
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
|
self.inner.borrow_mut().read_buf.take()
|
||||||
if !buf.is_empty() {
|
|
||||||
return Some(buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn get_write_buf(&self) -> Option<BytesMut> {
|
fn get_write_buf(&self) -> Option<BytesMut> {
|
||||||
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
|
self.inner.borrow_mut().write_buf.take()
|
||||||
if !buf.is_empty() {
|
|
||||||
return Some(buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn release_read_buf(
|
#[inline]
|
||||||
&self,
|
fn release_read_buf(&self, buf: BytesMut) {
|
||||||
io: &IoRef,
|
self.inner.borrow_mut().read_buf = Some(buf);
|
||||||
src: BytesMut,
|
}
|
||||||
dst: &mut Option<BytesMut>,
|
|
||||||
nbytes: usize,
|
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
|
||||||
) -> io::Result<usize> {
|
|
||||||
let mut inner = self.inner.borrow_mut();
|
let mut inner = self.inner.borrow_mut();
|
||||||
let mut session = self.session.borrow_mut();
|
let mut session = self.session.borrow_mut();
|
||||||
|
|
||||||
if session.is_handshaking() {
|
// ask inner filter to process read buf
|
||||||
inner.read_buf = Some(src);
|
match inner.filter.process_read_buf(io, nbytes) {
|
||||||
Ok(1)
|
Err(err) => io.want_shutdown(Some(err)),
|
||||||
} else {
|
Ok((_, 0)) => return Ok((0, 0)),
|
||||||
let mut src = {
|
Ok(_) => (),
|
||||||
let mut dst = None;
|
|
||||||
if let Err(err) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
|
|
||||||
io.want_shutdown(Some(err));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(dst) = dst {
|
if session.is_handshaking() {
|
||||||
|
Ok((0, 1))
|
||||||
|
} else {
|
||||||
|
// get processed buffer
|
||||||
|
let mut dst = if let Some(dst) = inner.read_buf.take() {
|
||||||
dst
|
dst
|
||||||
} else {
|
} else {
|
||||||
return Ok(0);
|
inner.pool.get_read_buf()
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let (hw, lw) = inner.pool.read_params().unpack();
|
let (hw, lw) = inner.pool.read_params().unpack();
|
||||||
|
|
||||||
// get inner filter buffer
|
let mut src = if let Some(src) = inner.filter.get_read_buf() {
|
||||||
if dst.is_none() {
|
src
|
||||||
*dst = Some(inner.pool.get_read_buf());
|
} else {
|
||||||
}
|
return Ok((0, 0));
|
||||||
let buf = dst.as_mut().unwrap();
|
};
|
||||||
|
|
||||||
let mut new_bytes = 0;
|
let mut new_bytes = 0;
|
||||||
loop {
|
loop {
|
||||||
// make sure we've got room
|
// make sure we've got room
|
||||||
let remaining = buf.remaining_mut();
|
let remaining = dst.remaining_mut();
|
||||||
if remaining < lw {
|
if remaining < lw {
|
||||||
buf.reserve(hw - remaining);
|
dst.reserve(hw - remaining);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut cursor = io::Cursor::new(&src);
|
let mut cursor = io::Cursor::new(&src);
|
||||||
|
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
||||||
|
|
||||||
let new_b = state.plaintext_bytes_to_read();
|
let new_b = state.plaintext_bytes_to_read();
|
||||||
if new_b > 0 {
|
if new_b > 0 {
|
||||||
buf.reserve(new_b);
|
dst.reserve(new_b);
|
||||||
let chunk: &mut [u8] =
|
let chunk: &mut [u8] =
|
||||||
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
|
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
|
||||||
let v = session.reader().read(chunk)?;
|
let v = session.reader().read(chunk)?;
|
||||||
unsafe { buf.advance_mut(v) };
|
unsafe { dst.advance_mut(v) };
|
||||||
new_bytes += v;
|
new_bytes += v;
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !src.is_empty() {
|
let dst_len = dst.len();
|
||||||
inner.read_buf = Some(src);
|
inner.read_buf = Some(dst);
|
||||||
}
|
inner.filter.release_read_buf(src);
|
||||||
Ok(new_bytes)
|
Ok((dst_len, new_bytes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,42 +157,6 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Wrapper<'a, F>(&'a mut IoInner<F>);
|
|
||||||
|
|
||||||
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
|
|
||||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
|
||||||
if let Some(read_buf) = self.0.read_buf.as_mut() {
|
|
||||||
let len = cmp::min(read_buf.len(), dst.len());
|
|
||||||
if len > 0 {
|
|
||||||
dst[..len].copy_from_slice(&read_buf.split_to(len));
|
|
||||||
Ok(len)
|
|
||||||
} else {
|
|
||||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
|
||||||
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
|
|
||||||
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
|
|
||||||
buf.reserve(src.len());
|
|
||||||
buf
|
|
||||||
} else {
|
|
||||||
BytesMut::with_capacity_in(src.len(), self.0.pool)
|
|
||||||
};
|
|
||||||
buf.extend_from_slice(src);
|
|
||||||
self.0.inner.release_write_buf(buf)?;
|
|
||||||
Ok(src.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Filter> TlsClientFilter<F> {
|
impl<F: Filter> TlsClientFilter<F> {
|
||||||
pub(crate) async fn create(
|
pub(crate) async fn create(
|
||||||
io: Io<F>,
|
io: Io<F>,
|
||||||
|
@ -222,12 +168,11 @@ impl<F: Filter> TlsClientFilter<F> {
|
||||||
Ok(session) => session,
|
Ok(session) => session,
|
||||||
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
||||||
};
|
};
|
||||||
let io = io.map_filter(|inner: F| {
|
let io = io.map_filter(|filter: F| {
|
||||||
let read_buf = inner.get_read_buf();
|
|
||||||
let inner = IoInner {
|
let inner = IoInner {
|
||||||
pool,
|
pool,
|
||||||
inner,
|
filter,
|
||||||
read_buf,
|
read_buf: None,
|
||||||
write_buf: None,
|
write_buf: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
#![allow(clippy::type_complexity)]
|
#![allow(clippy::type_complexity)]
|
||||||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll};
|
use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll};
|
||||||
|
|
||||||
use ntex_bytes::BytesMut;
|
use ntex_bytes::{BytesMut, PoolRef};
|
||||||
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
|
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
|
||||||
use ntex_util::time::Millis;
|
use ntex_util::time::Millis;
|
||||||
use tls_rust::{ClientConfig, ServerConfig, ServerName};
|
use tls_rust::{ClientConfig, ServerConfig, ServerName};
|
||||||
|
@ -101,16 +101,18 @@ impl<F: Filter> Filter for TlsFilter<F> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn release_read_buf(
|
fn release_read_buf(&self, buf: BytesMut) {
|
||||||
&self,
|
|
||||||
io: &IoRef,
|
|
||||||
src: BytesMut,
|
|
||||||
dst: &mut Option<BytesMut>,
|
|
||||||
nb: usize,
|
|
||||||
) -> io::Result<usize> {
|
|
||||||
match self.inner {
|
match self.inner {
|
||||||
InnerTlsFilter::Server(ref f) => f.release_read_buf(io, src, dst, nb),
|
InnerTlsFilter::Server(ref f) => f.release_read_buf(buf),
|
||||||
InnerTlsFilter::Client(ref f) => f.release_read_buf(io, src, dst, nb),
|
InnerTlsFilter::Client(ref f) => f.release_read_buf(buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> {
|
||||||
|
match self.inner {
|
||||||
|
InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb),
|
||||||
|
InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,3 +231,48 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
|
||||||
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
|
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) struct IoInner<F> {
|
||||||
|
filter: F,
|
||||||
|
pool: PoolRef,
|
||||||
|
read_buf: Option<BytesMut>,
|
||||||
|
write_buf: Option<BytesMut>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct Wrapper<'a, F>(&'a mut IoInner<F>);
|
||||||
|
|
||||||
|
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
|
||||||
|
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
||||||
|
if let Some(mut read_buf) = self.0.filter.get_read_buf() {
|
||||||
|
let len = cmp::min(read_buf.len(), dst.len());
|
||||||
|
let result = if len > 0 {
|
||||||
|
dst[..len].copy_from_slice(&read_buf.split_to(len));
|
||||||
|
Ok(len)
|
||||||
|
} else {
|
||||||
|
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||||
|
};
|
||||||
|
self.0.filter.release_read_buf(read_buf);
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
||||||
|
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
|
||||||
|
let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() {
|
||||||
|
buf.reserve(src.len());
|
||||||
|
buf
|
||||||
|
} else {
|
||||||
|
BytesMut::with_capacity_in(src.len(), self.0.pool)
|
||||||
|
};
|
||||||
|
buf.extend_from_slice(src);
|
||||||
|
self.0.filter.release_write_buf(buf)?;
|
||||||
|
Ok(src.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||||
use std::io::{self, Read as IoRead, Write as IoWrite};
|
use std::io::{self, Read as IoRead, Write as IoWrite};
|
||||||
use std::sync::Arc;
|
use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
|
||||||
use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
|
|
||||||
|
|
||||||
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
use ntex_bytes::{BufMut, BytesMut};
|
||||||
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
|
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
|
||||||
use ntex_util::{future::poll_fn, ready, time, time::Millis};
|
use ntex_util::{future::poll_fn, ready, time, time::Millis};
|
||||||
use tls_rust::{ServerConfig, ServerConnection};
|
use tls_rust::{ServerConfig, ServerConnection};
|
||||||
|
|
||||||
use crate::{rustls::TlsFilter, types};
|
use crate::rustls::{IoInner, TlsFilter, Wrapper};
|
||||||
|
use crate::types;
|
||||||
|
|
||||||
/// An implementation of SSL streams
|
/// An implementation of SSL streams
|
||||||
pub struct TlsServerFilter<F> {
|
pub struct TlsServerFilter<F> {
|
||||||
|
@ -16,13 +16,6 @@ pub struct TlsServerFilter<F> {
|
||||||
session: RefCell<ServerConnection>,
|
session: RefCell<ServerConnection>,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct IoInner<F> {
|
|
||||||
inner: F,
|
|
||||||
pool: PoolRef,
|
|
||||||
read_buf: Option<BytesMut>,
|
|
||||||
write_buf: Option<BytesMut>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Filter> Filter for TlsServerFilter<F> {
|
impl<F: Filter> Filter for TlsServerFilter<F> {
|
||||||
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
|
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
|
||||||
const H2: &[u8] = b"h2";
|
const H2: &[u8] = b"h2";
|
||||||
|
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
||||||
};
|
};
|
||||||
Some(Box::new(proto))
|
Some(Box::new(proto))
|
||||||
} else {
|
} else {
|
||||||
self.inner.borrow().inner.query(id)
|
self.inner.borrow().filter.query(id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
|
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
|
||||||
self.inner.borrow().inner.poll_shutdown()
|
self.inner.borrow().filter.poll_shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
|
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
|
||||||
self.inner.borrow().inner.poll_read_ready(cx)
|
self.inner.borrow().filter.poll_read_ready(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
||||||
self.inner.borrow().inner.poll_write_ready(cx)
|
self.inner.borrow().filter.poll_write_ready(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||||
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
|
self.inner.borrow_mut().read_buf.take()
|
||||||
if !buf.is_empty() {
|
|
||||||
return Some(buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn get_write_buf(&self) -> Option<BytesMut> {
|
fn get_write_buf(&self) -> Option<BytesMut> {
|
||||||
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
|
self.inner.borrow_mut().write_buf.take()
|
||||||
if !buf.is_empty() {
|
|
||||||
return Some(buf);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn release_read_buf(
|
#[inline]
|
||||||
&self,
|
fn release_read_buf(&self, buf: BytesMut) {
|
||||||
io: &IoRef,
|
self.inner.borrow_mut().read_buf = Some(buf);
|
||||||
src: BytesMut,
|
}
|
||||||
dst: &mut Option<BytesMut>,
|
|
||||||
nbytes: usize,
|
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
|
||||||
) -> io::Result<usize> {
|
|
||||||
let mut inner = self.inner.borrow_mut();
|
let mut inner = self.inner.borrow_mut();
|
||||||
let mut session = self.session.borrow_mut();
|
let mut session = self.session.borrow_mut();
|
||||||
|
|
||||||
if session.is_handshaking() {
|
// ask inner filter to process read buf
|
||||||
inner.read_buf = Some(src);
|
match inner.filter.process_read_buf(io, nbytes) {
|
||||||
Ok(1)
|
Err(err) => io.want_shutdown(Some(err)),
|
||||||
} else {
|
Ok((_, 0)) => return Ok((0, 0)),
|
||||||
let mut src = {
|
Ok(_) => (),
|
||||||
let mut dst = None;
|
|
||||||
if let Err(e) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
|
|
||||||
io.want_shutdown(Some(e));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(dst) = dst {
|
if session.is_handshaking() {
|
||||||
|
Ok((0, 1))
|
||||||
|
} else {
|
||||||
|
// get processed buffer
|
||||||
|
let mut dst = if let Some(dst) = inner.read_buf.take() {
|
||||||
dst
|
dst
|
||||||
} else {
|
} else {
|
||||||
return Ok(0);
|
inner.pool.get_read_buf()
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let (hw, lw) = inner.pool.read_params().unpack();
|
let (hw, lw) = inner.pool.read_params().unpack();
|
||||||
|
|
||||||
// get inner filter buffer
|
let mut src = if let Some(src) = inner.filter.get_read_buf() {
|
||||||
if dst.is_none() {
|
src
|
||||||
*dst = Some(inner.pool.get_read_buf());
|
} else {
|
||||||
}
|
return Ok((0, 0));
|
||||||
let buf = dst.as_mut().unwrap();
|
};
|
||||||
|
|
||||||
let mut new_bytes = 0;
|
let mut new_bytes = 0;
|
||||||
loop {
|
loop {
|
||||||
// make sure we've got room
|
// make sure we've got room
|
||||||
let remaining = buf.remaining_mut();
|
let remaining = dst.remaining_mut();
|
||||||
if remaining < lw {
|
if remaining < lw {
|
||||||
buf.reserve(hw - remaining);
|
dst.reserve(hw - remaining);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut cursor = io::Cursor::new(&src);
|
let mut cursor = io::Cursor::new(&src);
|
||||||
|
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
||||||
|
|
||||||
let new_b = state.plaintext_bytes_to_read();
|
let new_b = state.plaintext_bytes_to_read();
|
||||||
if new_b > 0 {
|
if new_b > 0 {
|
||||||
buf.reserve(new_b);
|
dst.reserve(new_b);
|
||||||
let chunk: &mut [u8] =
|
let chunk: &mut [u8] =
|
||||||
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
|
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
|
||||||
let v = session.reader().read(chunk)?;
|
let v = session.reader().read(chunk)?;
|
||||||
unsafe { buf.advance_mut(v) };
|
unsafe { dst.advance_mut(v) };
|
||||||
new_bytes += v;
|
new_bytes += v;
|
||||||
} else {
|
} else {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !src.is_empty() {
|
let dst_len = dst.len();
|
||||||
inner.read_buf = Some(src);
|
inner.read_buf = Some(dst);
|
||||||
}
|
inner.filter.release_read_buf(src);
|
||||||
Ok(new_bytes)
|
Ok((dst_len, new_bytes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,42 +156,6 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Wrapper<'a, F>(&'a mut IoInner<F>);
|
|
||||||
|
|
||||||
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
|
|
||||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
|
||||||
if let Some(read_buf) = self.0.read_buf.as_mut() {
|
|
||||||
let len = cmp::min(read_buf.len(), dst.len());
|
|
||||||
if len > 0 {
|
|
||||||
dst[..len].copy_from_slice(&read_buf.split_to(len));
|
|
||||||
Ok(len)
|
|
||||||
} else {
|
|
||||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
|
||||||
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
|
|
||||||
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
|
|
||||||
buf.reserve(src.len());
|
|
||||||
buf
|
|
||||||
} else {
|
|
||||||
BytesMut::with_capacity_in(src.len(), self.0.pool)
|
|
||||||
};
|
|
||||||
buf.extend_from_slice(src);
|
|
||||||
self.0.inner.release_write_buf(buf)?;
|
|
||||||
Ok(src.len())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F: Filter> TlsServerFilter<F> {
|
impl<F: Filter> TlsServerFilter<F> {
|
||||||
pub(crate) async fn create(
|
pub(crate) async fn create(
|
||||||
io: Io<F>,
|
io: Io<F>,
|
||||||
|
@ -222,12 +168,11 @@ impl<F: Filter> TlsServerFilter<F> {
|
||||||
Ok(session) => session,
|
Ok(session) => session,
|
||||||
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
||||||
};
|
};
|
||||||
let io = io.map_filter(|inner: F| {
|
let io = io.map_filter(|filter: F| {
|
||||||
let read_buf = inner.get_read_buf();
|
|
||||||
let inner = IoInner {
|
let inner = IoInner {
|
||||||
pool,
|
pool,
|
||||||
inner,
|
filter,
|
||||||
read_buf,
|
read_buf: None,
|
||||||
write_buf: None,
|
write_buf: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue