Process write buffer if filter wrote to write buffer during reading

This commit is contained in:
Nikolay Kim 2023-01-24 08:31:26 +01:00
parent dd2dda09d1
commit dec6fd3dd8
14 changed files with 132 additions and 86 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.2] - 2023-01-24
* Update ntex-io to 0.2.2
## [0.2.1] - 2023-01-23
* Update filter implementation

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.2.1"
version = "0.2.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,7 +26,7 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.19"
ntex-io = "0.2.1"
ntex-io = "0.2.2"
ntex-util = "0.2.0"
ntex-service = "1.0.0"
log = "0.4"
@ -39,7 +39,7 @@ tls_openssl = { version="0.10", package = "openssl", optional = true }
tls_rust = { version = "0.20", package = "rustls", optional = true }
[dev-dependencies]
ntex = { version = "0.6.0", features = ["openssl", "rustls", "tokio"] }
ntex = { version = "0.6.1", features = ["openssl", "rustls", "tokio"] }
env_logger = "0.10"
rustls-pemfile = { version = "1.0" }
webpki-roots = { version = "0.22" }

View file

@ -29,20 +29,20 @@ pub struct SslFilter {
}
struct IoInner {
inner_read_buf: Option<BytesVec>,
inner_write_buf: Option<BytesVec>,
source: Option<BytesVec>,
destination: Option<BytesVec>,
pool: PoolRef,
}
impl io::Read for IoInner {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut buf) = self.inner_read_buf.take() {
if let Some(mut buf) = self.source.take() {
if buf.is_empty() {
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let len = cmp::min(buf.len(), dst.len());
dst[..len].copy_from_slice(&buf.split_to(len));
self.inner_read_buf = Some(buf);
self.source = Some(buf);
Ok(len)
}
} else {
@ -53,14 +53,14 @@ impl io::Read for IoInner {
impl io::Write for IoInner {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.inner_write_buf.take() {
let mut buf = if let Some(mut buf) = self.destination.take() {
buf.reserve(src.len());
buf
} else {
BytesVec::with_capacity_in(src.len(), self.pool)
};
buf.extend_from_slice(src);
self.inner_write_buf = Some(buf);
self.destination = Some(buf);
Ok(src.len())
}
@ -74,28 +74,22 @@ impl SslFilter {
where
F: FnOnce() -> R,
{
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
self.inner.borrow_mut().get_mut().destination = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().source = buf.with_read_buf(|b| b.take_src());
let result = f();
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
buf.set_dst(self.inner.borrow_mut().get_mut().destination.take());
buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take()));
result
}
fn set_buffers(&self, buf: &mut WriteBuf<'_>) {
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
self.inner.borrow_mut().get_mut().destination = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().source = buf.with_read_buf(|b| b.take_src());
}
fn unset_buffers(&self, buf: &mut WriteBuf<'_>) {
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
buf.set_dst(self.inner.borrow_mut().get_mut().destination.take());
buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take()));
}
}
@ -171,7 +165,8 @@ impl FilterLayer for SslFilter {
buf.with_write_buf(|b| self.set_buffers(b));
let dst = buf.get_dst();
let mut new_bytes = usize::from(self.handshake.get());
//let mut new_bytes = usize::from(self.handshake.get());
let mut new_bytes = 1;
loop {
// make sure we've got room
self.pool.resize_read_buf(dst);
@ -228,11 +223,7 @@ impl FilterLayer for SslFilter {
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
buf.set_dst(
self.inner
.borrow_mut()
.get_mut()
.inner_write_buf
.take(),
self.inner.borrow_mut().get_mut().destination.take(),
);
Ok(())
}
@ -294,8 +285,8 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let inner = IoInner {
pool: io.memory_pool(),
inner_read_buf: None,
inner_write_buf: None,
source: None,
destination: None,
};
let filter = SslFilter {
pool: io.memory_pool(),
@ -352,8 +343,8 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
Box::pin(async move {
let inner = IoInner {
pool: io.memory_pool(),
inner_read_buf: None,
inner_write_buf: None,
source: None,
destination: None,
};
let filter = SslFilter {
pool: io.memory_pool(),

View file

@ -242,7 +242,7 @@ impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
self.1.get_dst().extend_from_slice(src);
self.1.with_dst_buf(|buf| buf.extend_from_slice(src));
Ok(src.len())
}