Update Filter trait usage

This commit is contained in:
Nikolay Kim 2022-01-12 22:08:50 +06:00
parent 1df005f53f
commit b49a5ed195
6 changed files with 203 additions and 275 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.1.2] - 2022-01-12
* Update Filter trait usage
## [0.1.1] - 2022-01-10
* Remove usage of ntex::io::Boxed types

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.1.1"
version = "0.1.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,8 +26,8 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.9"
ntex-io = "0.1.2"
ntex-util = "0.1.8"
ntex-io = "0.1.3"
ntex-util = "0.1.9"
ntex-service = "0.3.1"
pin-project-lite = "0.2"

View file

@ -27,25 +27,26 @@ pub struct PeerCertChain(pub Vec<X509>);
/// An implementation of SSL streams
pub struct SslFilter<F = Base> {
inner: RefCell<SslStream<IoInner<F>>>,
pool: PoolRef,
handshake: Cell<bool>,
read_buf: Cell<Option<BytesMut>>,
}
struct IoInner<F> {
inner: F,
pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
impl<F: Filter> io::Read for IoInner<F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(ref mut buf) = self.read_buf {
if let Some(mut buf) = self.inner.get_read_buf() {
if buf.is_empty() {
buf.clear();
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let len = cmp::min(buf.len(), dst.len());
dst[..len].copy_from_slice(&buf.split_to(len));
self.inner.release_read_buf(buf);
Ok(len)
}
} else {
@ -139,70 +140,54 @@ impl<F: Filter> Filter for SslFilter<F> {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
self.read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
self.inner.borrow_mut().get_mut().write_buf.take()
}
fn release_read_buf(
&self,
io: &IoRef,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
// store to read_buf
let pool = {
let mut inner = self.inner.borrow_mut();
let mut dst = None;
let result = inner
.get_ref()
.inner
.release_read_buf(io, src, &mut dst, nbytes);
if let Err(err) = result {
io.want_shutdown(Some(err));
}
if dst.is_some() {
inner.get_mut().read_buf = dst;
inner.get_ref().pool
} else {
return Ok(0);
}
};
let (hw, lw) = pool.read_params().unpack();
#[inline]
fn release_read_buf(&self, buf: BytesMut) {
self.read_buf.set(Some(buf));
}
// get inner filter buffer
if dst.is_none() {
*dst = Some(pool.get_read_buf());
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
// ask inner filter to process read buf
match self
.inner
.borrow_mut()
.get_ref()
.inner
.process_read_buf(io, nbytes)
{
Err(err) => io.want_shutdown(Some(err)),
Ok((n, 0)) => return Ok((n, 0)),
Ok((_, _)) => (),
}
let buf = dst.as_mut().unwrap();
// get processed buffer
let mut dst = if let Some(dst) = self.get_read_buf() {
dst
} else {
self.pool.get_read_buf()
};
let (hw, lw) = self.pool.read_params().unpack();
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
let remaining = dst.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
dst.reserve(hw - remaining);
}
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
return match ssl_result {
let result = match ssl_result {
Ok(v) => {
unsafe { buf.advance_mut(v) };
unsafe { dst.advance_mut(v) };
new_bytes += v;
continue;
}
@ -216,14 +201,16 @@ impl<F: Filter> Filter for SslFilter<F> {
self.handshake.set(false);
}
}
Ok(new_bytes)
Ok((dst.len(), new_bytes))
}
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
io.want_shutdown(None);
Ok(new_bytes)
Ok((dst.len(), new_bytes))
}
Err(e) => Err(map_to_ioerr(e)),
};
self.release_read_buf(dst);
return result;
}
}
@ -299,18 +286,18 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let pool = st.memory_pool();
let st = st.map_filter(|inner: F| {
let read_buf = inner.get_read_buf();
let inner = IoInner {
pool,
inner,
read_buf,
write_buf: None,
};
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
Ok::<_, Box<dyn Error>>(SslFilter {
inner: RefCell::new(ssl_stream),
pool,
read_buf: Cell::new(None),
handshake: Cell::new(true),
inner: RefCell::new(ssl_stream),
})
})?;
@ -352,18 +339,18 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
let ssl = self.ssl;
let pool = st.memory_pool();
let st = st.map_filter(|inner: F| {
let read_buf = inner.get_read_buf();
let inner = IoInner {
pool,
inner,
read_buf,
write_buf: None,
};
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
Ok::<_, Box<dyn Error>>(SslFilter {
inner: RefCell::new(ssl_stream),
pool,
read_buf: Cell::new(None),
handshake: Cell::new(true),
inner: RefCell::new(ssl_stream),
})
})?;

View file

@ -1,13 +1,13 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::RefCell, cmp, sync::Arc, task::Context, task::Poll};
use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_bytes::{BufMut, BytesMut};
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::{future::poll_fn, ready};
use tls_rust::{ClientConfig, ClientConnection, ServerName};
use super::TlsFilter;
use crate::rustls::{IoInner, TlsFilter, Wrapper};
use crate::types;
/// An implementation of SSL streams
@ -16,13 +16,6 @@ pub struct TlsClientFilter<F> {
session: RefCell<ClientConnection>,
}
struct IoInner<F> {
inner: F,
pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
impl<F: Filter> Filter for TlsClientFilter<F> {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2";
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
};
Some(Box::new(proto))
} else {
self.inner.borrow().inner.query(id)
self.inner.borrow().filter.query(id)
}
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.borrow().inner.poll_shutdown()
self.inner.borrow().filter.poll_shutdown()
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.borrow().inner.poll_read_ready(cx)
self.inner.borrow().filter.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.borrow().inner.poll_write_ready(cx)
self.inner.borrow().filter.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
self.inner.borrow_mut().read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
self.inner.borrow_mut().write_buf.take()
}
fn release_read_buf(
&self,
io: &IoRef,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
#[inline]
fn release_read_buf(&self, buf: BytesMut) {
self.inner.borrow_mut().read_buf = Some(buf);
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
let mut inner = self.inner.borrow_mut();
let mut session = self.session.borrow_mut();
if session.is_handshaking() {
inner.read_buf = Some(src);
Ok(1)
} else {
let mut src = {
let mut dst = None;
if let Err(err) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
io.want_shutdown(Some(err));
}
// ask inner filter to process read buf
match inner.filter.process_read_buf(io, nbytes) {
Err(err) => io.want_shutdown(Some(err)),
Ok((_, 0)) => return Ok((0, 0)),
Ok(_) => (),
}
if let Some(dst) = dst {
dst
} else {
return Ok(0);
}
if session.is_handshaking() {
Ok((0, 1))
} else {
// get processed buffer
let mut dst = if let Some(dst) = inner.read_buf.take() {
dst
} else {
inner.pool.get_read_buf()
};
let (hw, lw) = inner.pool.read_params().unpack();
// get inner filter buffer
if dst.is_none() {
*dst = Some(inner.pool.get_read_buf());
}
let buf = dst.as_mut().unwrap();
let mut src = if let Some(src) = inner.filter.get_read_buf() {
src
} else {
return Ok((0, 0));
};
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
let remaining = dst.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
dst.reserve(hw - remaining);
}
let mut cursor = io::Cursor::new(&src);
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
buf.reserve(new_b);
dst.reserve(new_b);
let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?;
unsafe { buf.advance_mut(v) };
unsafe { dst.advance_mut(v) };
new_bytes += v;
} else {
break;
}
}
if !src.is_empty() {
inner.read_buf = Some(src);
}
Ok(new_bytes)
let dst_len = dst.len();
inner.read_buf = Some(dst);
inner.filter.release_read_buf(src);
Ok((dst_len, new_bytes))
}
}
@ -175,42 +157,6 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
}
}
struct Wrapper<'a, F>(&'a mut IoInner<F>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(read_buf) = self.0.read_buf.as_mut() {
let len = cmp::min(read_buf.len(), dst.len());
if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesMut::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.inner.release_write_buf(buf)?;
Ok(src.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<F: Filter> TlsClientFilter<F> {
pub(crate) async fn create(
io: Io<F>,
@ -222,12 +168,11 @@ impl<F: Filter> TlsClientFilter<F> {
Ok(session) => session,
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
};
let io = io.map_filter(|inner: F| {
let read_buf = inner.get_read_buf();
let io = io.map_filter(|filter: F| {
let inner = IoInner {
pool,
inner,
read_buf,
filter,
read_buf: None,
write_buf: None,
};

View file

@ -1,9 +1,9 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::sync::Arc;
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll};
use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll};
use ntex_bytes::BytesMut;
use ntex_bytes::{BytesMut, PoolRef};
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::time::Millis;
use tls_rust::{ClientConfig, ServerConfig, ServerName};
@ -101,16 +101,18 @@ impl<F: Filter> Filter for TlsFilter<F> {
}
#[inline]
fn release_read_buf(
&self,
io: &IoRef,
src: BytesMut,
dst: &mut Option<BytesMut>,
nb: usize,
) -> io::Result<usize> {
fn release_read_buf(&self, buf: BytesMut) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(io, src, dst, nb),
InnerTlsFilter::Client(ref f) => f.release_read_buf(io, src, dst, nb),
InnerTlsFilter::Server(ref f) => f.release_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.release_read_buf(buf),
}
}
#[inline]
fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb),
InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb),
}
}
@ -229,3 +231,48 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
}
}
pub(crate) struct IoInner<F> {
filter: F,
pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
pub(crate) struct Wrapper<'a, F>(&'a mut IoInner<F>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut read_buf) = self.0.filter.get_read_buf() {
let len = cmp::min(read_buf.len(), dst.len());
let result = if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
};
self.0.filter.release_read_buf(read_buf);
result
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesMut::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.filter.release_write_buf(buf)?;
Ok(src.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

View file

@ -1,14 +1,14 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite};
use std::sync::Arc;
use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_bytes::{BufMut, BytesMut};
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::{future::poll_fn, ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection};
use crate::{rustls::TlsFilter, types};
use crate::rustls::{IoInner, TlsFilter, Wrapper};
use crate::types;
/// An implementation of SSL streams
pub struct TlsServerFilter<F> {
@ -16,13 +16,6 @@ pub struct TlsServerFilter<F> {
session: RefCell<ServerConnection>,
}
struct IoInner<F> {
inner: F,
pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
impl<F: Filter> Filter for TlsServerFilter<F> {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2";
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
};
Some(Box::new(proto))
} else {
self.inner.borrow().inner.query(id)
self.inner.borrow().filter.query(id)
}
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.borrow().inner.poll_shutdown()
self.inner.borrow().filter.poll_shutdown()
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.borrow().inner.poll_read_ready(cx)
self.inner.borrow().filter.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.borrow().inner.poll_write_ready(cx)
self.inner.borrow().filter.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
self.inner.borrow_mut().read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
self.inner.borrow_mut().write_buf.take()
}
fn release_read_buf(
&self,
io: &IoRef,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
#[inline]
fn release_read_buf(&self, buf: BytesMut) {
self.inner.borrow_mut().read_buf = Some(buf);
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
let mut inner = self.inner.borrow_mut();
let mut session = self.session.borrow_mut();
if session.is_handshaking() {
inner.read_buf = Some(src);
Ok(1)
} else {
let mut src = {
let mut dst = None;
if let Err(e) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
io.want_shutdown(Some(e));
}
// ask inner filter to process read buf
match inner.filter.process_read_buf(io, nbytes) {
Err(err) => io.want_shutdown(Some(err)),
Ok((_, 0)) => return Ok((0, 0)),
Ok(_) => (),
}
if let Some(dst) = dst {
dst
} else {
return Ok(0);
}
if session.is_handshaking() {
Ok((0, 1))
} else {
// get processed buffer
let mut dst = if let Some(dst) = inner.read_buf.take() {
dst
} else {
inner.pool.get_read_buf()
};
let (hw, lw) = inner.pool.read_params().unpack();
// get inner filter buffer
if dst.is_none() {
*dst = Some(inner.pool.get_read_buf());
}
let buf = dst.as_mut().unwrap();
let mut src = if let Some(src) = inner.filter.get_read_buf() {
src
} else {
return Ok((0, 0));
};
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
let remaining = dst.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
dst.reserve(hw - remaining);
}
let mut cursor = io::Cursor::new(&src);
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
buf.reserve(new_b);
dst.reserve(new_b);
let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?;
unsafe { buf.advance_mut(v) };
unsafe { dst.advance_mut(v) };
new_bytes += v;
} else {
break;
}
}
if !src.is_empty() {
inner.read_buf = Some(src);
}
Ok(new_bytes)
let dst_len = dst.len();
inner.read_buf = Some(dst);
inner.filter.release_read_buf(src);
Ok((dst_len, new_bytes))
}
}
@ -174,42 +156,6 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
}
}
struct Wrapper<'a, F>(&'a mut IoInner<F>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(read_buf) = self.0.read_buf.as_mut() {
let len = cmp::min(read_buf.len(), dst.len());
if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesMut::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.inner.release_write_buf(buf)?;
Ok(src.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<F: Filter> TlsServerFilter<F> {
pub(crate) async fn create(
io: Io<F>,
@ -222,12 +168,11 @@ impl<F: Filter> TlsServerFilter<F> {
Ok(session) => session,
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
};
let io = io.map_filter(|inner: F| {
let read_buf = inner.get_read_buf();
let io = io.map_filter(|filter: F| {
let inner = IoInner {
pool,
inner,
read_buf,
filter,
read_buf: None,
write_buf: None,
};