Process write buffer if filter wrote to write buffer during reading

This commit is contained in:
Nikolay Kim 2023-01-24 08:31:26 +01:00
parent dd2dda09d1
commit dec6fd3dd8
14 changed files with 132 additions and 86 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.2] - 2023-01-24
* Process write buffer if filter wrote to write buffer during reading
## [0.2.1] - 2023-01-23
* Refactor Io and Filter types

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.2.1"
version = "0.2.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -30,4 +30,4 @@ smallvec = "1"
rand = "0.8"
env_logger = "0.10"
ntex = { version = "0.6.0", features = ["tokio"] }
ntex = { version = "0.6.1", features = ["tokio"] }

View file

@ -38,6 +38,7 @@ impl Stack {
nbytes,
curr: &mut curr[idx],
next: &mut next[0],
need_write: false,
};
f(&mut buf)
} else {
@ -49,6 +50,7 @@ impl Stack {
nbytes,
curr: &mut val1,
next: &mut val2,
need_write: false,
};
let result = f(&mut buf);
@ -69,6 +71,7 @@ impl Stack {
io,
curr: &mut curr[idx],
next: &mut next[0],
need_write: false,
};
f(&mut buf)
} else {
@ -79,6 +82,7 @@ impl Stack {
io,
curr: &mut val1,
next: &mut val2,
need_write: false,
};
let result = f(&mut buf);
@ -152,6 +156,7 @@ pub struct ReadBuf<'a> {
pub(crate) curr: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) next: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) nbytes: usize,
pub(crate) need_write: bool,
}
impl<'a> ReadBuf<'a> {
@ -254,8 +259,11 @@ impl<'a> ReadBuf<'a> {
io: self.io,
curr: self.curr,
next: self.next,
need_write: self.need_write,
};
f(&mut buf)
let result = f(&mut buf);
self.need_write = buf.need_write;
result
}
}
@ -264,6 +272,7 @@ pub struct WriteBuf<'a> {
pub(crate) io: &'a IoRef,
pub(crate) curr: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) next: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) need_write: bool,
}
impl<'a> WriteBuf<'a> {
@ -302,11 +311,17 @@ impl<'a> WriteBuf<'a> {
#[inline]
/// Get reference to destination write buffer
pub fn get_dst(&mut self) -> &mut BytesVec {
pub fn with_dst_buf<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut BytesVec) -> R,
{
if self.next.1.is_none() {
self.next.1 = Some(self.io.memory_pool().get_write_buf());
}
self.next.1.as_mut().unwrap()
let buf = self.next.1.as_mut().unwrap();
let r = f(buf);
self.need_write |= !buf.is_empty();
r
}
#[inline]
@ -328,23 +343,12 @@ impl<'a> WriteBuf<'a> {
if let Some(b) = self.next.1.take() {
self.io.memory_pool().release_write_buf(b);
}
self.need_write |= !dst.is_empty();
self.next.1 = Some(dst);
}
}
}
#[inline]
/// Get reference to source and destination buffers (src, dst)
pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) {
if self.curr.1.is_none() {
self.curr.1 = Some(self.io.memory_pool().get_write_buf());
}
if self.next.1.is_none() {
self.next.1 = Some(self.io.memory_pool().get_write_buf());
}
(self.curr.1.as_mut().unwrap(), self.next.1.as_mut().unwrap())
}
#[inline]
pub fn with_read_buf<'b, F, R>(&'b mut self, f: F) -> R
where
@ -355,7 +359,10 @@ impl<'a> WriteBuf<'a> {
curr: self.curr,
next: self.next,
nbytes: 0,
need_write: self.need_write,
};
f(&mut buf)
let result = f(&mut buf);
self.need_write = buf.need_write;
result
}
}

View file

@ -29,6 +29,12 @@ impl NullFilter {
}
}
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
pub struct FilterReadStatus {
pub nbytes: usize,
pub need_write: bool,
}
pub trait Filter: 'static {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>>;
@ -38,7 +44,7 @@ pub trait Filter: 'static {
stack: &mut Stack,
idx: usize,
nbytes: usize,
) -> io::Result<usize>;
) -> io::Result<FilterReadStatus>;
/// Process write buffer
fn process_write_buf(
@ -112,8 +118,11 @@ impl Filter for Base {
_: &mut Stack,
_: usize,
nbytes: usize,
) -> io::Result<usize> {
Ok(nbytes)
) -> io::Result<FilterReadStatus> {
Ok(FilterReadStatus {
nbytes,
need_write: false,
})
}
#[inline]
@ -167,13 +176,18 @@ where
stack: &mut Stack,
idx: usize,
nbytes: usize,
) -> io::Result<usize> {
let nbytes = if F::BUFFERS {
) -> io::Result<FilterReadStatus> {
let status = if F::BUFFERS {
self.1.process_read_buf(io, stack, idx + 1, nbytes)?
} else {
self.1.process_read_buf(io, stack, idx, nbytes)?
};
stack.read_buf(io, idx, nbytes, |buf| self.0.process_read_buf(buf))
stack.read_buf(io, idx, status.nbytes, |buf| {
self.0.process_read_buf(buf).map(|nbytes| FilterReadStatus {
nbytes,
need_write: status.need_write || buf.need_write,
})
})
}
#[inline]
@ -254,8 +268,8 @@ impl Filter for NullFilter {
_: &mut Stack,
_: usize,
_: usize,
) -> io::Result<usize> {
Ok(0)
) -> io::Result<FilterReadStatus> {
Ok(Default::default())
}
#[inline]

View file

@ -406,7 +406,7 @@ mod tests {
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
self.write_order.borrow_mut().push(self.idx);
self.out_bytes
.set(self.out_bytes.get() + buf.get_dst().len());
.set(self.out_bytes.get() + buf.with_dst_buf(|b| b.len()));
Ok(())
}
}

View file

@ -24,6 +24,7 @@ impl ReadContext {
F: FnOnce(&mut BytesVec, usize, usize) -> Poll<io::Result<()>>,
{
let mut stack = self.0 .0.buffer.borrow_mut();
let is_write_sleep = stack.last_write_buf_size() == 0;
let mut buf = stack
.last_read_buf()
.take()
@ -50,8 +51,8 @@ impl ReadContext {
.filter()
.process_read_buf(&self.0, &mut stack, 0, nbytes)
{
Ok(nbytes) => {
if nbytes > 0 {
Ok(status) => {
if status.nbytes > 0 {
if buf_full || stack.first_read_buf_size() >= hw {
log::trace!(
"io read buffer is too large {}, enable read back-pressure",
@ -70,8 +71,27 @@ impl ReadContext {
);
} else if buf_full {
// read task is paused because of read back-pressure
// but there is no new data in top most read buffer
// so we need to wake up read task to read more data
// otherwise read task would sleep forever
self.0 .0.read_task.wake();
}
// while reading, filter wrote some data
// in that case filters need to process write buffers
// and potentialy wake write task
if status.need_write {
if let Err(err) =
self.0.filter().process_write_buf(&self.0, &mut stack, 0)
{
self.0 .0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::RD_READY);
self.0 .0.init_shutdown(Some(err), &self.0);
}
if is_write_sleep && stack.last_write_buf_size() != 0 {
self.0 .0.write_task.wake();
}
}
}
Err(err) => {
self.0 .0.dispatch_task.wake();

View file

@ -182,7 +182,7 @@ mod tests {
NullFilter
.process_read_buf(&ioref, &mut stack, 0, 0)
.unwrap(),
(0)
Default::default()
)
}
}

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.2] - 2023-01-24
* Update ntex-io to 0.2.2
## [0.2.1] - 2023-01-23
* Update filter implementation

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.2.1"
version = "0.2.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,7 +26,7 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.19"
ntex-io = "0.2.1"
ntex-io = "0.2.2"
ntex-util = "0.2.0"
ntex-service = "1.0.0"
log = "0.4"
@ -39,7 +39,7 @@ tls_openssl = { version="0.10", package = "openssl", optional = true }
tls_rust = { version = "0.20", package = "rustls", optional = true }
[dev-dependencies]
ntex = { version = "0.6.0", features = ["openssl", "rustls", "tokio"] }
ntex = { version = "0.6.1", features = ["openssl", "rustls", "tokio"] }
env_logger = "0.10"
rustls-pemfile = { version = "1.0" }
webpki-roots = { version = "0.22" }

View file

@ -29,20 +29,20 @@ pub struct SslFilter {
}
struct IoInner {
inner_read_buf: Option<BytesVec>,
inner_write_buf: Option<BytesVec>,
source: Option<BytesVec>,
destination: Option<BytesVec>,
pool: PoolRef,
}
impl io::Read for IoInner {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut buf) = self.inner_read_buf.take() {
if let Some(mut buf) = self.source.take() {
if buf.is_empty() {
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let len = cmp::min(buf.len(), dst.len());
dst[..len].copy_from_slice(&buf.split_to(len));
self.inner_read_buf = Some(buf);
self.source = Some(buf);
Ok(len)
}
} else {
@ -53,14 +53,14 @@ impl io::Read for IoInner {
impl io::Write for IoInner {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.inner_write_buf.take() {
let mut buf = if let Some(mut buf) = self.destination.take() {
buf.reserve(src.len());
buf
} else {
BytesVec::with_capacity_in(src.len(), self.pool)
};
buf.extend_from_slice(src);
self.inner_write_buf = Some(buf);
self.destination = Some(buf);
Ok(src.len())
}
@ -74,28 +74,22 @@ impl SslFilter {
where
F: FnOnce() -> R,
{
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
self.inner.borrow_mut().get_mut().destination = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().source = buf.with_read_buf(|b| b.take_src());
let result = f();
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
buf.set_dst(self.inner.borrow_mut().get_mut().destination.take());
buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take()));
result
}
fn set_buffers(&self, buf: &mut WriteBuf<'_>) {
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
self.inner.borrow_mut().get_mut().destination = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().source = buf.with_read_buf(|b| b.take_src());
}
fn unset_buffers(&self, buf: &mut WriteBuf<'_>) {
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
buf.set_dst(self.inner.borrow_mut().get_mut().destination.take());
buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take()));
}
}
@ -171,7 +165,8 @@ impl FilterLayer for SslFilter {
buf.with_write_buf(|b| self.set_buffers(b));
let dst = buf.get_dst();
let mut new_bytes = usize::from(self.handshake.get());
//let mut new_bytes = usize::from(self.handshake.get());
let mut new_bytes = 1;
loop {
// make sure we've got room
self.pool.resize_read_buf(dst);
@ -228,11 +223,7 @@ impl FilterLayer for SslFilter {
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
buf.set_dst(
self.inner
.borrow_mut()
.get_mut()
.inner_write_buf
.take(),
self.inner.borrow_mut().get_mut().destination.take(),
);
Ok(())
}
@ -294,8 +285,8 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let inner = IoInner {
pool: io.memory_pool(),
inner_read_buf: None,
inner_write_buf: None,
source: None,
destination: None,
};
let filter = SslFilter {
pool: io.memory_pool(),
@ -352,8 +343,8 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
Box::pin(async move {
let inner = IoInner {
pool: io.memory_pool(),
inner_read_buf: None,
inner_write_buf: None,
source: None,
destination: None,
};
let filter = SslFilter {
pool: io.memory_pool(),

View file

@ -242,7 +242,7 @@ impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
self.1.get_dst().extend_from_slice(src);
self.1.with_dst_buf(|buf| buf.extend_from_slice(src));
Ok(src.len())
}

View file

@ -1,5 +1,9 @@
# Changes
## [0.6.2] - 2023-01-24
* Update ntex-io, ntex-tls deps
## [0.6.1] - 2023-01-23
* Refactor io subsystem

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.6.1"
version = "0.6.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -56,10 +56,10 @@ ntex-service = "1.0.0"
ntex-macros = "0.1.3"
ntex-util = "0.2.0"
ntex-bytes = "0.1.19"
ntex-h2 = "0.2.0"
ntex-h2 = "0.2.1"
ntex-rt = "0.4.7"
ntex-io = "0.2.1"
ntex-tls = "0.2.1"
ntex-io = "0.2.2"
ntex-tls = "0.2.2"
ntex-tokio = { version = "0.2.1", optional = true }
ntex-glommio = { version = "0.2.1", optional = true }
ntex-async-std = { version = "0.2.1", optional = true }

View file

@ -67,13 +67,15 @@ impl FilterLayer for WsTransport {
} else {
CloseCode::Normal
};
let _ = self.codec.encode_vec(
Message::Close(Some(CloseReason {
code,
description: None,
})),
buf.get_dst(),
);
let _ = buf.with_dst_buf(|buf| {
self.codec.encode_vec(
Message::Close(Some(CloseReason {
code,
description: None,
})),
buf,
)
});
}
Ok(Poll::Ready(()))
}
@ -131,7 +133,7 @@ impl FilterLayer for WsTransport {
}
Frame::Ping(msg) => {
let _ = buf.with_write_buf(|b| {
self.codec.encode_vec(Message::Pong(msg), b.get_dst())
b.with_dst_buf(|b| self.codec.encode_vec(Message::Pong(msg), b))
});
}
Frame::Pong(_) => (),
@ -153,17 +155,17 @@ impl FilterLayer for WsTransport {
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(src) = buf.take_src() {
let dst = buf.get_dst();
buf.with_dst_buf(|dst| {
// make sure we've got room
let (hw, lw) = self.pool.write_params().unpack();
let remaining = dst.remaining_mut();
if remaining < lw {
dst.reserve(cmp::max(hw, dst.len() + 12) - remaining);
}
// make sure we've got room
let (hw, lw) = self.pool.write_params().unpack();
let remaining = dst.remaining_mut();
if remaining < lw {
dst.reserve(cmp::max(hw, dst.len() + 12) - remaining);
}
// Encoder ws::Codec do not fail
let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst);
// Encoder ws::Codec do not fail
let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst);
});
}
Ok(())
}