Refactor io buffers api (#169)

* Refactor io buffers api
This commit is contained in:
Nikolay Kim 2023-01-29 23:02:12 +06:00 committed by GitHub
parent 3b6cf6a3ef
commit 0f8387c3ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 573 additions and 553 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.4] - 2023-01-29
* Update buffer api
## [0.2.3] - 2023-01-25
* Fix double buf cleanup

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.2.3"
version = "0.2.4"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,7 +26,7 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.19"
ntex-io = "0.2.3"
ntex-io = "0.2.7"
ntex-util = "0.2.0"
ntex-service = "1.0.0"
log = "0.4"
@ -39,7 +39,7 @@ tls_openssl = { version="0.10", package = "openssl", optional = true }
tls_rust = { version = "0.20", package = "rustls", optional = true }
[dev-dependencies]
ntex = { version = "0.6.1", features = ["openssl", "rustls", "tokio"] }
ntex = { version = "0.6.3", features = ["openssl", "rustls", "tokio"] }
env_logger = "0.10"
rustls-pemfile = { version = "1.0" }
webpki-roots = { version = "0.22" }

View file

@ -2,7 +2,7 @@
use std::cell::{Cell, RefCell};
use std::{any, cmp, error::Error, io, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesVec, PoolRef};
use ntex_bytes::{BufMut, BytesVec};
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis};
use tls_openssl::ssl::{self, NameType, SslStream};
@ -24,25 +24,22 @@ pub struct PeerCertChain(pub Vec<X509>);
/// An implementation of SSL streams
pub struct SslFilter {
inner: RefCell<SslStream<IoInner>>,
pool: PoolRef,
handshake: Cell<bool>,
}
struct IoInner {
source: Option<BytesVec>,
destination: Option<BytesVec>,
pool: PoolRef,
}
impl io::Read for IoInner {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut buf) = self.source.take() {
if let Some(ref mut buf) = self.source {
if buf.is_empty() {
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let len = cmp::min(buf.len(), dst.len());
dst[..len].copy_from_slice(&buf.split_to(len));
self.source = Some(buf);
Ok(len)
}
} else {
@ -53,13 +50,7 @@ impl io::Read for IoInner {
impl io::Write for IoInner {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(buf) = self.destination.take() {
buf
} else {
BytesVec::with_capacity_in(src.len(), self.pool)
};
buf.extend_from_slice(src);
self.destination = Some(buf);
self.destination.as_mut().unwrap().extend_from_slice(src);
Ok(src.len())
}
@ -69,7 +60,7 @@ impl io::Write for IoInner {
}
impl SslFilter {
fn with_buffers<F, R>(&self, buf: &mut WriteBuf<'_>, f: F) -> R
fn with_buffers<F, R>(&self, buf: &WriteBuf<'_>, f: F) -> R
where
F: FnOnce() -> R,
{
@ -80,16 +71,6 @@ impl SslFilter {
buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take()));
result
}
fn set_buffers(&self, buf: &mut WriteBuf<'_>) {
self.inner.borrow_mut().get_mut().destination = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().source = buf.with_read_buf(|b| b.take_src());
}
fn unset_buffers(&self, buf: &mut WriteBuf<'_>) {
buf.set_dst(self.inner.borrow_mut().get_mut().destination.take());
buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take()));
}
}
impl FilterLayer for SslFilter {
@ -141,7 +122,7 @@ impl FilterLayer for SslFilter {
}
}
fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
fn shutdown(&self, buf: &WriteBuf<'_>) -> io::Result<Poll<()>> {
let ssl_result = self.with_buffers(buf, || self.inner.borrow_mut().shutdown());
match ssl_result {
@ -160,75 +141,72 @@ impl FilterLayer for SslFilter {
}
}
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
buf.with_write_buf(|b| self.set_buffers(b));
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
buf.with_write_buf(|b| {
self.with_buffers(b, || {
buf.with_dst(|dst| {
let mut new_bytes = usize::from(self.handshake.get());
loop {
buf.resize_buf(dst);
let dst = buf.get_dst();
//let mut new_bytes = usize::from(self.handshake.get());
let mut new_bytes = 1;
loop {
// make sure we've got room
self.pool.resize_read_buf(dst);
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
let result = match ssl_result {
Ok(v) => {
unsafe { dst.advance_mut(v) };
new_bytes += v;
continue;
}
Err(ref e)
if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE =>
{
Ok(new_bytes)
}
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
buf.want_shutdown();
Ok(new_bytes)
}
Err(e) => {
log::trace!("SSL Error: {:?}", e);
Err(map_to_ioerr(e))
}
};
buf.with_write_buf(|b| self.unset_buffers(b));
return result;
}
let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
let result = match ssl_result {
Ok(v) => {
unsafe { dst.advance_mut(v) };
new_bytes += v;
continue;
}
Err(ref e)
if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE =>
{
Ok(new_bytes)
}
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
buf.want_shutdown();
Ok(new_bytes)
}
Err(e) => {
log::trace!("SSL Error: {:?}", e);
Err(map_to_ioerr(e))
}
};
return result;
}
})
})
})
}
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(mut src) = buf.take_src() {
self.set_buffers(buf);
loop {
if src.is_empty() {
self.unset_buffers(buf);
return Ok(());
}
let ssl_result = self.inner.borrow_mut().ssl_write(&src);
match ssl_result {
Ok(v) => {
src.split_to(v);
continue;
fn process_write_buf(&self, wb: &WriteBuf<'_>) -> io::Result<()> {
wb.with_src(|b| {
if let Some(src) = b {
self.with_buffers(wb, || loop {
if src.is_empty() {
return Ok(());
}
Err(e) => {
buf.set_src(Some(src));
self.unset_buffers(buf);
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
Ok(())
}
_ => Err(map_to_ioerr(e)),
};
let ssl_result = self.inner.borrow_mut().ssl_write(src);
match ssl_result {
Ok(v) => {
src.split_to(v);
continue;
}
Err(e) => {
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
Ok(())
}
_ => Err(map_to_ioerr(e)),
};
}
}
}
})
} else {
Ok(())
}
} else {
Ok(())
}
})
}
}
@ -278,12 +256,10 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
time::timeout(timeout, async {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let inner = IoInner {
pool: io.memory_pool(),
source: None,
destination: None,
};
let filter = SslFilter {
pool: io.memory_pool(),
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
@ -336,12 +312,10 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
fn create(self, io: Io<F>) -> Self::Future {
Box::pin(async move {
let inner = IoInner {
pool: io.memory_pool(),
source: None,
destination: None,
};
let filter = SslFilter {
pool: io.memory_pool(),
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
};

View file

@ -56,62 +56,60 @@ impl FilterLayer for TlsClientFilter {
}
}
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
let mut session = self.session.borrow_mut();
let mut new_bytes = usize::from(self.inner.handshake.get());
// get processed buffer
let (src, dst) = buf.get_pair();
let mut new_bytes = usize::from(self.inner.handshake.get());
loop {
// make sure we've got room
self.inner.pool.resize_read_buf(dst);
buf.with_src(|src| {
if let Some(src) = src {
buf.with_dst(|dst| {
loop {
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
src.split_to(n);
let state = session
.process_new_packets()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
src.split_to(n);
let state = session
.process_new_packets()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
dst.reserve(new_b);
let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?;
unsafe { dst.advance_mut(v) };
new_bytes += v;
} else {
break;
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
dst.reserve(new_b);
let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?;
unsafe { dst.advance_mut(v) };
new_bytes += v;
} else {
break;
}
}
Ok::<_, io::Error>(())
})?;
}
}
Ok(new_bytes)
Ok(new_bytes)
})
}
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(mut src) = buf.take_src() {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
buf.with_src(|src| {
if let Some(src) = src {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
loop {
if !src.is_empty() {
src.split_to(session.writer().write(src)?);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
}
}
buf.set_src(Some(src));
Ok(())
} else {
Ok(())
}
})
}
}
@ -125,7 +123,6 @@ impl TlsClientFilter {
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_client(TlsClientFilter {
inner: IoInner {
pool: io.memory_pool(),
handshake: Cell::new(true),
},
session: RefCell::new(session),

View file

@ -2,7 +2,6 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll};
use ntex_bytes::PoolRef;
use ntex_io::{
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
WriteStatus,
@ -71,7 +70,7 @@ impl FilterLayer for TlsFilter {
}
#[inline]
fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
fn shutdown(&self, buf: &WriteBuf<'_>) -> io::Result<Poll<()>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.shutdown(buf),
InnerTlsFilter::Client(ref f) => f.shutdown(buf),
@ -95,7 +94,7 @@ impl FilterLayer for TlsFilter {
}
#[inline]
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_read_buf(buf),
@ -103,7 +102,7 @@ impl FilterLayer for TlsFilter {
}
#[inline]
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_write_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_write_buf(buf),
@ -219,30 +218,31 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
}
pub(crate) struct IoInner {
pool: PoolRef,
handshake: Cell<bool>,
}
pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a mut WriteBuf<'b>);
pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a WriteBuf<'b>);
impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
self.1.with_read_buf(|buf| {
let read_buf = buf.get_src();
let len = cmp::min(read_buf.len(), dst.len());
if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
buf.with_src(|buf| {
if let Some(buf) = buf {
let len = cmp::min(buf.len(), dst.len());
if len > 0 {
dst[..len].copy_from_slice(&buf.split_to(len));
return Ok(len);
}
}
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
})
})
}
}
impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
self.1.with_dst_buf(|buf| buf.extend_from_slice(src));
self.1.with_dst(|buf| buf.extend_from_slice(src));
Ok(src.len())
}

View file

@ -63,60 +63,60 @@ impl FilterLayer for TlsServerFilter {
}
}
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
let mut session = self.session.borrow_mut();
let mut new_bytes = usize::from(self.inner.handshake.get());
// get processed buffer
let (src, dst) = buf.get_pair();
let mut new_bytes = usize::from(self.inner.handshake.get());
loop {
// make sure we've got room
self.inner.pool.resize_read_buf(dst);
buf.with_src(|src| {
if let Some(src) = src {
buf.with_dst(|dst| {
loop {
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
src.split_to(n);
let state = session
.process_new_packets()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
src.split_to(n);
let state = session
.process_new_packets()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
dst.reserve(new_b);
let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?;
unsafe { dst.advance_mut(v) };
new_bytes += v;
} else {
break;
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
dst.reserve(new_b);
let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?;
unsafe { dst.advance_mut(v) };
new_bytes += v;
} else {
break;
}
}
Ok::<_, io::Error>(())
})?;
}
}
Ok(new_bytes)
Ok(new_bytes)
})
}
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(mut src) = buf.take_src() {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
buf.with_src(|src| {
if let Some(src) = src {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
loop {
if !src.is_empty() {
src.split_to(session.writer().write(src)?);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
}
}
buf.set_src(Some(src));
}
Ok(())
Ok(())
})
}
}
@ -132,7 +132,6 @@ impl TlsServerFilter {
let filter = TlsFilter::new_server(TlsServerFilter {
session: RefCell::new(session),
inner: IoInner {
pool: io.memory_pool(),
handshake: Cell::new(true),
},
});