Fix read filters ordering

This commit is contained in:
Nikolay Kim 2021-12-25 21:21:34 +06:00
parent 531bafbae2
commit 079a1c9cbf
23 changed files with 642 additions and 325 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.1.0-b.6] - 2021-12-xx
* Fix read filters ordering
## [0.1.0-b.5] - 2021-12-24
* Use new ntex-service traits

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.1.0-b.5"
version = "0.1.0-b.6"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]

View file

@ -67,7 +67,7 @@ impl Filter for Base {
if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
Poll::Ready(ReadStatus::Terminate)
} else if flags.intersects(Flags::RD_PAUSED) {
} else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
self.0 .0.read_task.register(cx.waker());
Poll::Pending
} else {
@ -98,7 +98,7 @@ impl Filter for Base {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
self.0 .0.read_buf.take()
None
}
#[inline]
@ -107,21 +107,18 @@ impl Filter for Base {
}
#[inline]
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error> {
if nbytes > 0 {
if buf.len() > self.0.memory_pool().read_params().high as usize {
log::trace!(
"buffer is too large {}, enable read back-pressure",
buf.len()
);
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else {
self.0 .0.insert_flags(Flags::RD_READY);
}
fn release_read_buf(
&self,
buf: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
if let Some(ref mut dst) = dst {
dst.extend_from_slice(&buf)
} else {
*dst = Some(buf)
}
self.0 .0.read_buf.set(Some(buf));
Ok(())
Ok(nbytes)
}
#[inline]
@ -178,8 +175,13 @@ impl Filter for NullFilter {
None
}
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> {
Ok(())
fn release_read_buf(
&self,
_: BytesMut,
_: &mut Option<BytesMut>,
_: usize,
) -> io::Result<usize> {
Ok(0)
}
fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> {

View file

@ -279,7 +279,8 @@ impl fmt::Debug for IoRef {
#[cfg(test)]
mod tests {
use std::{cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
use std::cell::{Cell, RefCell};
use std::{future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
use ntex_bytes::Bytes;
use ntex_codec::BytesCodec;
@ -393,9 +394,12 @@ mod tests {
}
struct Counter<F> {
idx: usize,
inner: F,
in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>,
read_order: Rc<RefCell<Vec<usize>>>,
write_order: Rc<RefCell<Vec<usize>>>,
}
impl<F: Filter> Filter for Counter<F> {
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
@ -425,10 +429,13 @@ mod tests {
fn release_read_buf(
&self,
buf: BytesMut,
dst: &mut Option<BytesMut>,
new_bytes: usize,
) -> Result<(), io::Error> {
self.in_bytes.set(self.in_bytes.get() + new_bytes);
self.inner.release_read_buf(buf, new_bytes)
) -> io::Result<usize> {
let result = self.inner.release_read_buf(buf, dst, new_bytes)?;
self.read_order.borrow_mut().push(self.idx);
self.in_bytes.set(self.in_bytes.get() + result);
Ok(result)
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
@ -445,12 +452,19 @@ mod tests {
}
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
self.write_order.borrow_mut().push(self.idx);
self.out_bytes.set(self.out_bytes.get() + buf.len());
self.inner.release_write_buf(buf)
}
}
struct CounterFactory(Rc<Cell<usize>>, Rc<Cell<usize>>);
struct CounterFactory(
usize,
Rc<Cell<usize>>,
Rc<Cell<usize>>,
Rc<RefCell<Vec<usize>>>,
Rc<RefCell<Vec<usize>>>,
);
impl<F: Filter> FilterFactory<F> for CounterFactory {
type Filter = Counter<F>;
@ -459,14 +473,20 @@ mod tests {
type Future = Ready<Io<Counter<F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future {
let in_bytes = self.0.clone();
let out_bytes = self.1.clone();
let idx = self.0;
let in_bytes = self.1.clone();
let out_bytes = self.2.clone();
let read_order = self.3.clone();
let write_order = self.4.clone();
Ready::Ok(
io.map_filter(|inner| {
Ok::<_, ()>(Counter {
idx,
inner,
in_bytes,
out_bytes,
read_order,
write_order,
})
})
.unwrap(),
@ -478,7 +498,15 @@ mod tests {
async fn filter() {
let in_bytes = Rc::new(Cell::new(0));
let out_bytes = Rc::new(Cell::new(0));
let factory = CounterFactory(in_bytes.clone(), out_bytes.clone());
let read_order = Rc::new(RefCell::new(Vec::new()));
let write_order = Rc::new(RefCell::new(Vec::new()));
let factory = CounterFactory(
1,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
);
let (client, server) = IoTest::create();
let state = Io::new(server).add_filter(factory).await.unwrap();
@ -503,13 +531,27 @@ mod tests {
async fn boxed_filter() {
let in_bytes = Rc::new(Cell::new(0));
let out_bytes = Rc::new(Cell::new(0));
let read_order = Rc::new(RefCell::new(Vec::new()));
let write_order = Rc::new(RefCell::new(Vec::new()));
let (client, server) = IoTest::create();
let state = Io::new(server)
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.add_filter(CounterFactory(
1,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
))
.await
.unwrap()
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.add_filter(CounterFactory(
2,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
))
.await
.unwrap();
let state = state.seal();
@ -533,5 +575,7 @@ mod tests {
assert_eq!(Rc::strong_count(&in_bytes), 3);
drop(state);
assert_eq!(Rc::strong_count(&in_bytes), 1);
assert_eq!(*read_order.borrow(), &[1, 2][..]);
assert_eq!(*write_order.borrow(), &[2, 1][..]);
}
}

View file

@ -1,6 +1,6 @@
use std::{
any::Any, any::TypeId, fmt, future::Future, io::Error as IoError, task::Context,
task::Poll,
any::Any, any::TypeId, fmt, future::Future, io as sio, io::Error as IoError,
task::Context, task::Poll,
};
pub mod testing;
@ -55,7 +55,7 @@ pub trait Filter: 'static {
/// Filter wants gracefully shutdown io stream
fn want_shutdown(&self);
fn poll_shutdown(&self) -> Poll<std::io::Result<()>>;
fn poll_shutdown(&self) -> Poll<sio::Result<()>>;
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
@ -65,11 +65,16 @@ pub trait Filter: 'static {
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> std::io::Result<()>;
fn release_read_buf(
&self,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> sio::Result<usize>;
fn release_write_buf(&self, buf: BytesMut) -> std::io::Result<()>;
fn release_write_buf(&self, buf: BytesMut) -> sio::Result<()>;
fn closed(&self, err: Option<std::io::Error>);
fn closed(&self, err: Option<sio::Error>);
}
pub trait FilterFactory<F: Filter>: Sized {

View file

@ -31,27 +31,35 @@ impl ReadContext {
}
#[inline]
pub fn release_read_buf(
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
pub fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error> {
if buf.is_empty() {
self.0.memory_pool().release_read_buf(buf);
Ok(())
} else {
let mut flags = self.0.flags();
if new_bytes > 0 {
flags.insert(Flags::RD_READY);
self.0.set_flags(flags);
let mut dst = self.0 .0.read_buf.take();
let nbytes = self.0.filter().release_read_buf(buf, &mut dst, nbytes)?;
if let Some(dst) = dst {
if self.0.flags().contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters()?;
}
if nbytes > 0 {
if dst.len() > self.0.memory_pool().read_params().high as usize {
log::trace!(
"buffer is too large {}, enable read back-pressure",
dst.len()
);
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else {
self.0 .0.insert_flags(Flags::RD_READY);
log::trace!("new {} bytes available, wakeup dispatcher", nbytes);
}
self.0 .0.dispatch_task.wake();
}
self.0 .0.read_buf.set(Some(dst));
} else if nbytes > 0 {
self.0 .0.dispatch_task.wake();
log::trace!("new {} bytes available, wakeup dispatcher", new_bytes);
}
self.0.filter().release_read_buf(buf, new_bytes)?;
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters()?;
self.0 .0.insert_flags(Flags::RD_READY);
}
Ok(())
}

View file

@ -51,62 +51,72 @@ impl Future for ReadTask {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(ReadStatus::Ready) => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
loop {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => break,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
if let Err(e) = this.state.release_read_buf(buf, new_bytes)
{
this.state.close(Some(e));
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("tokio stream is disconnected");
close = true;
} else {
this.state.close(None);
new_bytes += n;
if new_bytes < hw {
continue;
}
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
}
}
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else {
Poll::Pending
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
return if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
}
}
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
}
@ -466,63 +476,72 @@ mod unixstream {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(ReadStatus::Ready) => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
loop {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => break,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
if let Err(e) =
this.state.release_read_buf(buf, new_bytes)
{
this.state.close(Some(e));
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("unix stream is disconnected");
close = true;
} else {
this.state.close(None);
new_bytes += n;
if new_bytes < hw {
continue;
}
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
}
}
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else {
Poll::Pending
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
return if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
}
}
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
}

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.1.0-b.3"
version = "0.1.0-b.4"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,7 +26,7 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.8"
ntex-io = "0.1.0-b.5"
ntex-io = "0.1.0-b.6"
ntex-util = "0.1.4"
ntex-service = "0.3.0-b.0"
pin-project-lite = "0.2"

View file

@ -1,6 +1,6 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::RefCell;
use std::cell::{Cell, RefCell};
use std::{
any, cmp, error::Error, future::Future, io, pin::Pin, task::Context, task::Poll,
};
@ -18,6 +18,7 @@ use super::types;
/// An implementation of SSL streams
pub struct SslFilter<F = Base> {
inner: RefCell<SslStream<IoInner<F>>>,
handshake: Cell<bool>,
}
struct IoInner<F> {
@ -128,12 +129,15 @@ impl<F: Filter> Filter for SslFilter<F> {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
let mut inner = self.inner.borrow_mut();
inner.get_ref().inner.get_read_buf().or_else(|| {
if let Some(buf) = inner.get_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
}
None
None
})
}
#[inline]
@ -146,25 +150,34 @@ impl<F: Filter> Filter for SslFilter<F> {
None
}
fn release_read_buf(&self, src: BytesMut, nbytes: usize) -> Result<(), io::Error> {
fn release_read_buf(
&self,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
// store to read_buf
let pool = {
let mut inner = self.inner.borrow_mut();
inner.get_mut().read_buf = Some(src);
inner.get_ref().pool
let mut dst = None;
inner
.get_ref()
.inner
.release_read_buf(src, &mut dst, nbytes)?;
if dst.is_some() {
inner.get_mut().read_buf = dst;
inner.get_ref().pool
} else {
return Ok(0);
}
};
if nbytes == 0 {
return Ok(());
}
let (hw, lw) = pool.read_params().unpack();
// get inner filter buffer
let mut buf = if let Some(buf) = self.inner.borrow().get_ref().inner.get_read_buf()
{
buf
} else {
BytesMut::with_capacity_in(lw, pool)
};
if dst.is_none() {
*dst = Some(pool.get_read_buf());
}
let buf = dst.as_mut().unwrap();
let mut new_bytes = 0;
loop {
@ -186,11 +199,13 @@ impl<F: Filter> Filter for SslFilter<F> {
if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE =>
{
self.inner
.borrow()
.get_ref()
.inner
.release_read_buf(buf, new_bytes)
if new_bytes == 0 && self.handshake.get() {
new_bytes = 1;
if self.inner.borrow().ssl().is_init_finished() {
self.handshake.set(false);
}
}
Ok(new_bytes)
}
Err(e) => Err(map_to_ioerr(e)),
};
@ -274,6 +289,7 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
Ok::<_, Box<dyn Error>>(SslFilter {
inner: RefCell::new(ssl_stream),
handshake: Cell::new(true),
})
})?;
@ -326,6 +342,7 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
Ok::<_, Box<dyn Error>>(SslFilter {
inner: RefCell::new(ssl_stream),
handshake: Cell::new(true),
})
})?;

View file

@ -78,12 +78,15 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
let mut inner = self.inner.borrow_mut();
inner.inner.get_read_buf().or_else(|| {
if let Some(buf) = inner.read_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
}
None
None
})
}
#[inline]
@ -96,24 +99,36 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
None
}
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
fn release_read_buf(
&self,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
let mut inner = self.inner.borrow_mut();
let mut session = self.session.borrow_mut();
if session.is_handshaking() {
self.inner.borrow_mut().read_buf = Some(src);
Ok(())
inner.read_buf = Some(src);
Ok(1)
} else {
if src.is_empty() {
return Ok(());
}
let mut inner = self.inner.borrow_mut();
let mut src = {
let mut dst = None;
inner.inner.release_read_buf(src, &mut dst, nbytes)?;
if let Some(dst) = dst {
dst
} else {
return Ok(0);
}
};
let (hw, lw) = inner.pool.read_params().unpack();
// get inner filter buffer
let mut buf = if let Some(buf) = inner.inner.get_read_buf() {
buf
} else {
BytesMut::with_capacity_in(lw, inner.pool)
};
if dst.is_none() {
*dst = Some(inner.pool.get_read_buf());
}
let buf = dst.as_mut().unwrap();
let mut new_bytes = 0;
loop {
@ -146,7 +161,7 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
if !src.is_empty() {
inner.read_buf = Some(src);
}
inner.inner.release_read_buf(buf, new_bytes)
Ok(new_bytes)
}
}

View file

@ -125,10 +125,15 @@ impl<F: Filter> Filter for TlsFilter<F> {
}
#[inline]
fn release_read_buf(&self, src: BytesMut, nb: usize) -> Result<(), io::Error> {
fn release_read_buf(
&self,
src: BytesMut,
dst: &mut Option<BytesMut>,
nb: usize,
) -> io::Result<usize> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(src, nb),
InnerTlsFilter::Client(ref f) => f.release_read_buf(src, nb),
InnerTlsFilter::Server(ref f) => f.release_read_buf(src, dst, nb),
InnerTlsFilter::Client(ref f) => f.release_read_buf(src, dst, nb),
}
}

View file

@ -78,12 +78,15 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
let mut inner = self.inner.borrow_mut();
inner.inner.get_read_buf().or_else(|| {
if let Some(buf) = inner.read_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
}
None
None
})
}
#[inline]
@ -96,25 +99,36 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
None
}
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
let mut session = self.session.borrow_mut();
fn release_read_buf(
&self,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
let mut inner = self.inner.borrow_mut();
let mut session = self.session.borrow_mut();
if session.is_handshaking() {
inner.read_buf = Some(src);
Ok(())
Ok(1)
} else {
if src.is_empty() {
return Ok(());
}
let mut src = {
let mut dst = None;
inner.inner.release_read_buf(src, &mut dst, nbytes)?;
if let Some(dst) = dst {
dst
} else {
return Ok(0);
}
};
let (hw, lw) = inner.pool.read_params().unpack();
// get inner filter buffer
let mut buf = if let Some(buf) = inner.inner.get_read_buf() {
buf
} else {
BytesMut::with_capacity_in(lw, inner.pool)
};
if dst.is_none() {
*dst = Some(inner.pool.get_read_buf());
}
let buf = dst.as_mut().unwrap();
let mut new_bytes = 0;
loop {
@ -147,7 +161,7 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
if !src.is_empty() {
inner.read_buf = Some(src);
}
inner.inner.release_read_buf(buf, new_bytes)
Ok(new_bytes)
}
}

View file

@ -1,5 +1,9 @@
# Changes
## [0.5.0-b.4] - 2021-12-xx
* Allow to get access to ws transport codec
## [0.5.0-b.3] - 2021-12-24
* Use new ntex-service traits

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.5.0-b.3"
version = "0.5.0-b.4"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -45,8 +45,8 @@ ntex-service = "0.3.0-b.0"
ntex-macros = "0.1.3"
ntex-util = "0.1.4"
ntex-bytes = "0.1.8"
ntex-tls = "0.1.0-b.3"
ntex-io = "0.1.0-b.5"
ntex-tls = "0.1.0-b.4"
ntex-io = "0.1.0-b.6"
ntex-rt = { version = "0.4.0-b.2", default-features = false, features = ["tokio"] }
base64 = "0.13"

View file

@ -154,11 +154,9 @@ mod tests {
let factory = Connector::new(config).clone();
let srv = factory.new_service(()).await.unwrap();
let _result = srv
let result = srv
.call(Connect::new("www.rust-lang.org").set_addr(Some(server.addr())))
.await;
// TODO! fix
// assert!(result.is_err());
assert!(result.is_err());
}
}

View file

@ -340,7 +340,9 @@ impl TestServer {
#[cfg(feature = "openssl")]
/// Connect to a websocket server
pub async fn wss(&mut self) -> Result<WsConnection<impl Filter>, WsClientError> {
pub async fn wss(
&mut self,
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> {
self.wss_at("/").await
}
@ -349,7 +351,7 @@ impl TestServer {
pub async fn wss_at(
&mut self,
path: &str,
) -> Result<WsConnection<impl Filter>, WsClientError> {
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> {
use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();

View file

@ -81,7 +81,9 @@ pub mod io {
pub mod testing {
//! IO testing utilities.
#[doc(hidden)]
pub use ntex_io::testing::IoTest as Io;
pub use ntex_io::testing::IoTest;
}
pub mod tls {

View file

@ -57,12 +57,7 @@ struct Inner<F, T> {
impl WsClient<Base, ()> {
/// Create new websocket client builder
pub fn build<U>(
uri: U,
) -> WsClientBuilder<
Base,
impl Service<Connect<Uri>, Response = Io, Error = ConnectError>,
>
pub fn build<U>(uri: U) -> WsClientBuilder<Base, Connector<Uri>>
where
Uri: TryFrom<U>,
<Uri as TryFrom<U>>::Error: Into<HttpError>,
@ -269,12 +264,7 @@ impl<F, T> fmt::Debug for WsClient<F, T> {
impl WsClientBuilder<Base, ()> {
/// Create new websocket connector
fn new<U>(
uri: U,
) -> WsClientBuilder<
Base,
impl Service<Connect<Uri>, Response = Io, Error = ConnectError>,
>
fn new<U>(uri: U) -> WsClientBuilder<Base, Connector<Uri>>
where
Uri: TryFrom<U>,
<Uri as TryFrom<U>>::Error: Into<HttpError>,
@ -673,6 +663,11 @@ impl<F> WsConnection<F> {
Self { io, codec, res }
}
/// Get codec reference
pub fn codec(&self) -> &ws::Codec {
&self.codec
}
/// Get reference to response
pub fn response(&self) -> &ClientResponse {
&self.res

View file

@ -91,7 +91,7 @@ impl<F: Filter> Filter for WsTransport<F> {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
self.read_buf.take()
self.inner.get_read_buf().or_else(|| self.read_buf.take())
}
#[inline]
@ -99,115 +99,127 @@ impl<F: Filter> Filter for WsTransport<F> {
None
}
fn release_read_buf(&self, mut src: BytesMut, nbytes: usize) -> Result<(), io::Error> {
if nbytes == 0 {
if !src.is_empty() {
self.read_buf.set(Some(src));
}
Ok(())
} else {
let (hw, lw) = self.pool.read_params().unpack();
fn release_read_buf(
&self,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
let mut src = {
let mut dst = None;
self.inner.release_read_buf(src, &mut dst, nbytes)?;
// get inner filter buffer
let mut buf = if let Some(buf) = self.inner.get_read_buf() {
buf
if let Some(dst) = dst {
dst
} else {
self.pool.get_read_buf()
};
let len = buf.len();
let mut flags = self.flags.get();
return Ok(0);
}
};
let (hw, lw) = self.pool.read_params().unpack();
// read from input buffer
loop {
let result = self.codec.decode(&mut src).map_err(|e| {
log::trace!("ws codec failed to decode bytes stream: {:?}", e);
io::Error::new(io::ErrorKind::Other, e)
})?;
// get outter filter buffer
if dst.is_none() {
*dst = Some(self.pool.get_read_buf());
}
let buf = dst.as_mut().unwrap();
let mut flags = self.flags.get();
let mut nbytes = 0;
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
// read from input buffer
loop {
let result = self.codec.decode(&mut src).map_err(|e| {
log::trace!("ws codec failed to decode bytes stream: {:?}", e);
io::Error::new(io::ErrorKind::Other, e)
})?;
match result {
Some(frame) => match frame {
Frame::Binary(bin) => buf.extend_from_slice(&bin),
Frame::Continuation(item) => match item {
Item::FirstText(_) => {
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSocket text continuation frames are not supported",
));
}
Item::FirstBinary(bin) => {
flags = self.insert_flags(Flags::CONTINUATION);
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match result {
Some(frame) => match frame {
Frame::Binary(bin) => {
nbytes += bin.len();
buf.extend_from_slice(&bin)
}
Frame::Continuation(item) => match item {
Item::FirstText(_) => {
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSocket text continuation frames are not supported",
));
}
Item::FirstBinary(bin) => {
nbytes += bin.len();
buf.extend_from_slice(&bin);
flags = self.insert_flags(Flags::CONTINUATION);
}
Item::Continue(bin) => {
if flags.contains(Flags::CONTINUATION) {
nbytes += bin.len();
buf.extend_from_slice(&bin);
}
Item::Continue(bin) => {
if flags.contains(Flags::CONTINUATION) {
buf.extend_from_slice(&bin);
} else {
return Err(io::Error::new(
} else {
return Err(io::Error::new(
io::ErrorKind::Other,
"Continuation frame must follow data frame with FIN bit clear",
));
}
}
Item::Last(bin) => {
if flags.contains(Flags::CONTINUATION) {
flags = self.remove_flags(Flags::CONTINUATION);
buf.extend_from_slice(&bin);
} else {
return Err(io::Error::new(
}
Item::Last(bin) => {
if flags.contains(Flags::CONTINUATION) {
nbytes += bin.len();
buf.extend_from_slice(&bin);
flags = self.remove_flags(Flags::CONTINUATION);
} else {
return Err(io::Error::new(
io::ErrorKind::Other,
"Received last frame without initial continuation frame",
));
}
}
},
Frame::Text(_) => {
log::trace!("WebSocket text frames are not supported");
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSocket text frames are not supported",
));
}
Frame::Ping(msg) => {
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
self.codec
.encode(Message::Pong(msg), &mut b)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.release_write_buf(b)?;
}
Frame::Pong(_) => (),
Frame::Close(_) => {
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
self.codec
.encode(Message::Close(None), &mut b)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.release_write_buf(b)?;
break;
}
},
None => break,
}
Frame::Text(_) => {
log::trace!("WebSocket text frames are not supported");
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSocket text frames are not supported",
));
}
Frame::Ping(msg) => {
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
self.codec
.encode(Message::Pong(msg), &mut b)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.release_write_buf(b)?;
}
Frame::Pong(_) => (),
Frame::Close(_) => {
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
self.codec
.encode(Message::Close(None), &mut b)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
self.release_write_buf(b)?;
break;
}
},
None => break,
}
if !src.is_empty() {
self.read_buf.set(Some(src));
} else {
self.pool.release_read_buf(src);
}
let new_bytes = buf.len() - len;
self.inner.release_read_buf(buf, new_bytes)
}
if !src.is_empty() {
self.read_buf.set(Some(src));
} else {
self.pool.release_read_buf(src);
}
Ok(nbytes)
}
fn release_write_buf(&self, src: BytesMut) -> Result<(), io::Error> {
@ -253,3 +265,39 @@ impl<F: Filter> FilterFactory<F> for WsTransportFactory {
Ready::from(st.map_filter(|inner| Ok(WsTransport::new(inner, self.codec, pool))))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{codec::BytesCodec, io::Io, testing::IoTest, util::Bytes};
#[crate::rt_test]
async fn basics() {
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
server.remote_buffer_cap(1024);
let client = WsTransportFactory::new(Codec::new().client_mode())
.create(Io::new(client))
.await
.unwrap();
let server = WsTransportFactory::new(Codec::new())
.create(Io::new(server))
.await
.unwrap();
client
.send(&BytesCodec, Bytes::from_static(b"DATA"))
.await
.unwrap();
let res = server.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(res, b"DATA".as_ref());
server
.send(&BytesCodec, Bytes::from_static(b"DATA"))
.await
.unwrap();
let res = client.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(res, b"DATA".as_ref());
}
}

View file

@ -5,13 +5,16 @@ use futures::future::{err, ok, ready};
use futures::stream::{once, Stream, StreamExt};
use tls_openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
use ntex::codec::BytesCodec;
use ntex::http::error::PayloadError;
use ntex::http::header::{self, HeaderName, HeaderValue};
use ntex::http::test::server as test_server;
use ntex::http::{body, HttpService, Method, Request, Response, StatusCode, Version};
use ntex::http::ws::handshake_response;
use ntex::http::{body, h1, HttpService, Method, Request, Response, StatusCode, Version};
use ntex::io::Io;
use ntex::service::{fn_service, ServiceFactory};
use ntex::util::{Bytes, BytesMut};
use ntex::{time::Seconds, web::error::InternalError};
use ntex::util::{Bytes, BytesMut, Ready};
use ntex::{time::Seconds, web::error::InternalError, ws};
async fn load_body<S>(stream: S) -> Result<BytesMut, PayloadError>
where
@ -435,3 +438,49 @@ async fn test_ssl_handshake_timeout() {
let _ = stream.read_to_string(&mut data);
assert!(data.is_empty());
}
#[ntex::test]
async fn test_ws_transport() {
let mut srv = test_server(|| {
HttpService::build()
.upgrade(|(req, io, codec): (Request, Io<_>, h1::Codec)| {
async move {
let res = handshake_response(req.head()).finish();
// send handshake respone
io.encode(
h1::Message::Item((res.drop_body(), body::BodySize::None)),
&codec,
)
.unwrap();
let io = io
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
.await?;
// start websocket service
loop {
if let Some(item) =
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
{
io.send(&BytesCodec, item.freeze()).await.unwrap()
} else {
break;
}
}
Ok::<_, io::Error>(())
}
})
.finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound()))
.openssl(ssl_acceptor())
});
let io = srv.wss().await.unwrap().into_inner().0;
let codec = ws::Codec::default().client_mode();
io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text")))
.await
.unwrap();
let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text")));
}

View file

@ -2,6 +2,9 @@ use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::{cell::Cell, future::Future, io, pin::Pin};
use ntex::codec::BytesCodec;
use ntex::http::test::server as test_server;
use ntex::http::ws::handshake_response;
use ntex::http::{
body, h1, test, ws::handshake, HttpService, Request, Response, StatusCode,
};
@ -239,3 +242,49 @@ async fn test_simple() {
assert!(ws_service.was_polled());
}
#[ntex::test]
async fn test_transport() {
let mut srv = test_server(|| {
HttpService::build()
.upgrade(|(req, io, codec): (Request, Io, h1::Codec)| {
async move {
let res = handshake_response(req.head()).finish();
// send handshake respone
io.encode(
h1::Message::Item((res.drop_body(), body::BodySize::None)),
&codec,
)
.unwrap();
let io = io
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
.await?;
// start websocket service
loop {
if let Some(item) =
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
{
io.send(&BytesCodec, item.freeze()).await.unwrap()
} else {
break;
}
}
Ok::<_, io::Error>(())
}
})
.finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound()))
});
// client service
let io = srv.ws().await.unwrap().into_inner().0;
let codec = ws::Codec::default().client_mode();
io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text")))
.await
.unwrap();
let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text")));
}

View file

@ -1,5 +1,6 @@
use std::io;
use ntex::codec::BytesCodec;
use ntex::http::test::server as test_server;
use ntex::http::ws::handshake_response;
use ntex::http::{body::BodySize, h1, HttpService, Request, Response};
@ -79,3 +80,39 @@ async fn test_simple() {
let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
}
#[ntex::test]
async fn test_transport() {
env_logger::init();
let mut srv = test_server(|| {
HttpService::build()
.upgrade(|(req, io, codec): (Request, Io, h1::Codec)| {
async move {
let res = handshake_response(req.head()).finish();
// send handshake respone
io.encode(h1::Message::Item((res.drop_body(), BodySize::None)), &codec)
.unwrap();
// start websocket service
Dispatcher::new(
io.seal(),
ws::Codec::default(),
ws_service,
Default::default(),
)
.await
}
})
.finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound()))
});
// client service
let io = srv.ws().await.unwrap().into_transport().await;
io.send(&BytesCodec, Bytes::from_static(b"text"))
.await
.unwrap();
let item = io.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(item, Bytes::from_static(b"text"));
}

View file

@ -4,7 +4,7 @@ use std::{sync::mpsc, thread, time::Duration};
use tls_openssl::ssl::SslAcceptorBuilder;
use ntex::web::{self, App, HttpResponse, HttpServer};
use ntex::{rt, server::TestServer, time::Seconds};
use ntex::{rt, server::TestServer, time::sleep, time::Seconds};
#[cfg(unix)]
#[ntex::test]
@ -203,7 +203,7 @@ async fn test_rustls() {
// stop
let _ = srv.stop(false);
thread::sleep(Duration::from_millis(100));
sleep(Duration::from_millis(100)).await;
sys.stop();
}
@ -253,7 +253,7 @@ async fn test_bind_uds() {
// stop
let _ = srv.stop(false);
thread::sleep(Duration::from_millis(100));
sleep(Duration::from_millis(100)).await;
sys.stop();
}
@ -305,6 +305,6 @@ async fn test_listen_uds() {
// stop
let _ = srv.stop(false);
thread::sleep(Duration::from_millis(100));
sleep(Duration::from_millis(100)).await;
sys.stop();
}