Refactor ntex-io (#164)

* Refactor Io and Filter types
This commit is contained in:
Nikolay Kim 2023-01-23 14:42:00 +06:00 committed by GitHub
parent 8cbd8758a5
commit 83d05d81ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 1615 additions and 1495 deletions

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-async-std"
version = "0.2.0"
version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "async-std intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"]
@ -16,8 +16,8 @@ name = "ntex_async_std"
path = "src/lib.rs"
[dependencies]
ntex-bytes = "0.1.11"
ntex-io = "0.2.0"
ntex-bytes = "0.1.19"
ntex-io = "0.2.1"
ntex-util = "0.2.0"
async-oneshot = "0.5.0"
log = "0.4"

View file

@ -1,4 +1,4 @@
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll};
use std::{any, cell::RefCell, future::Future, io, pin::Pin, task::Context, task::Poll};
use async_std::io::{Read, Write};
use ntex_bytes::{Buf, BufMut, BytesVec};
@ -30,35 +30,31 @@ impl Handle for TcpStream {
/// Read io task
struct ReadTask {
io: TcpStream,
io: RefCell<TcpStream>,
state: ReadContext,
}
impl ReadTask {
/// Create new read io task
fn new(io: TcpStream, state: ReadContext) -> Self {
Self { io, state }
Self {
state,
io: RefCell::new(io),
}
}
}
impl Future for ReadTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
loop {
this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let io = &mut this.io;
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
let mut io = self.io.borrow_mut();
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
@ -66,52 +62,31 @@ impl Future for ReadTask {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
return match poll_read_buf(Pin::new(&mut io.0), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("async-std stream is disconnected");
close = true;
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
new_bytes += n;
if new_bytes <= hw {
continue;
}
Poll::Pending
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
log::trace!("async-std read task failed on io {:?}", err);
Poll::Ready(Err(err))
}
}
};
}
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
Poll::Ready(Ok(()))
}
}
}
})
}
}
@ -358,10 +333,6 @@ pub fn poll_read_buf<T: Read>(
cx: &mut Context<'_>,
buf: &mut BytesVec,
) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) };
let n = ready!(io.poll_read(cx, dst))?;
@ -389,35 +360,31 @@ mod unixstream {
/// Read io task
struct ReadTask {
io: UnixStream,
io: RefCell<UnixStream>,
state: ReadContext,
}
impl ReadTask {
/// Create new read io task
fn new(io: UnixStream, state: ReadContext) -> Self {
Self { io, state }
Self {
state,
io: RefCell::new(io),
}
}
}
impl Future for ReadTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
loop {
this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let io = &mut this.io;
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
let mut io = this.io.borrow_mut();
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
@ -425,52 +392,31 @@ mod unixstream {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
return match poll_read_buf(Pin::new(&mut io.0), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("async-std stream is disconnected");
close = true;
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
new_bytes += n;
if new_bytes <= hw {
continue;
}
Poll::Pending
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
Poll::Ready(Err(err))
}
}
};
}
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
Poll::Ready(Ok(()))
}
}
}
})
}
}

View file

@ -1,5 +1,9 @@
# Changes
## [0.1.19] (2023-01-23)
* Add PollRef::resize_read_buf() and PollRef::resize_write_buf() helpers
## [0.1.18] (2022-12-13)
* Add Bytes<&Bytes> for Bytes impl

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-bytes"
version = "0.1.18"
version = "0.1.19"
license = "MIT"
authors = ["Nikolay Kim <fafhrd91@gmail.com>", "Carl Lerche <me@carllerche.com>"]
description = "Types and traits for working with bytes (bytes crate fork)"

View file

@ -6,7 +6,7 @@ use std::{cell::Cell, cell::RefCell, fmt, future::Future, mem, pin::Pin, ptr, rc
use futures_core::task::__internal::AtomicWaker;
use crate::{BytesMut, BytesVec};
use crate::{BufMut, BytesMut, BytesVec};
pub struct Pool {
idx: Cell<usize>,
@ -293,6 +293,17 @@ impl PoolRef {
}
}
#[doc(hidden)]
#[inline]
/// Resize read buffer
pub fn resize_read_buf(self, buf: &mut BytesVec) {
let (hw, lw) = self.0.write_wm.get().unpack();
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
}
#[doc(hidden)]
#[inline]
/// Release read buffer, buf must be allocated from this pool
@ -318,6 +329,17 @@ impl PoolRef {
}
}
#[doc(hidden)]
#[inline]
/// Resize write buffer
pub fn resize_write_buf(self, buf: &mut BytesVec) {
let (hw, lw) = self.0.write_wm.get().unpack();
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
}
#[doc(hidden)]
#[inline]
/// Release write buffer, buf must be allocated from this pool

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.1] - 2023-01-23
* Use new Io object
## [0.2.0] - 2023-01-04
* Release

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-connect"
version = "0.2.0"
version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "ntexwork connect utils for ntex framework"
keywords = ["network", "framework", "async", "futures"]
@ -35,18 +35,18 @@ async-std = ["ntex-rt/async-std", "ntex-async-std"]
[dependencies]
ntex-service = "1.0.0"
ntex-bytes = "0.1.18"
ntex-bytes = "0.1.19"
ntex-http = "0.1.8"
ntex-io = "0.2.0"
ntex-io = "0.2.1"
ntex-rt = "0.4.7"
ntex-tls = "0.2.0"
ntex-tls = "0.2.1"
ntex-util = "0.2.0"
log = "0.4"
thiserror = "1.0"
ntex-tokio = { version = "0.2.0", optional = true }
ntex-glommio = { version = "0.2.0", optional = true }
ntex-async-std = { version = "0.2.0", optional = true }
ntex-tokio = { version = "0.2.1", optional = true }
ntex-glommio = { version = "0.2.1", optional = true }
ntex-async-std = { version = "0.2.1", optional = true }
# openssl
tls-openssl = { version="0.10", package = "openssl", optional = true }

View file

@ -4,7 +4,7 @@ pub use ntex_tls::openssl::SslFilter;
pub use tls_openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
use ntex_bytes::PoolId;
use ntex_io::{Base, Io};
use ntex_io::{FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory};
use ntex_tls::openssl::SslConnector as IoSslConnector;
use ntex_util::future::{BoxFuture, Ready};
@ -39,7 +39,7 @@ impl<T> Connector<T> {
impl<T: Address> Connector<T> {
/// Resolve and connect to remote host
pub async fn connect<U>(&self, message: U) -> Result<Io<SslFilter<Base>>, ConnectError>
pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<SslFilter>>, ConnectError>
where
Connect<T>: From<U>,
{
@ -57,7 +57,7 @@ impl<T: Address> Connector<T> {
let ssl = config
.into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
match io.add_filter(IoSslConnector::new(ssl)).await {
match IoSslConnector::new(ssl).create(io).await {
Ok(io) => {
trace!("SSL Handshake success: {:?}", host);
Ok(io)
@ -82,7 +82,7 @@ impl<T> Clone for Connector<T> {
}
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<SslFilter<Base>>;
type Response = Io<Layer<SslFilter>>;
type Error = ConnectError;
type Service = Connector<T>;
type InitError = ();
@ -95,7 +95,7 @@ impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
}
impl<T: Address> Service<Connect<T>> for Connector<T> {
type Response = Io<SslFilter<Base>>;
type Response = Io<Layer<SslFilter>>;
type Error = ConnectError;
type Future<'f> = BoxFuture<'f, Result<Self::Response, Self::Error>>;

View file

@ -4,7 +4,7 @@ pub use ntex_tls::rustls::TlsFilter;
pub use tls_rustls::{ClientConfig, ServerName};
use ntex_bytes::PoolId;
use ntex_io::{Base, Io};
use ntex_io::{FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory};
use ntex_tls::rustls::TlsConnector;
use ntex_util::future::{BoxFuture, Ready};
@ -48,7 +48,7 @@ impl<T> Connector<T> {
impl<T: Address + 'static> Connector<T> {
/// Resolve and connect to remote host
pub async fn connect<U>(&self, message: U) -> Result<Io<TlsFilter<Base>>, ConnectError>
pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<TlsFilter>>, ConnectError>
where
Connect<T>: From<U>,
{
@ -64,7 +64,7 @@ impl<T: Address + 'static> Connector<T> {
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;
let connector = connector.server_name(host.clone());
match io.add_filter(connector).await {
match connector.create(io).await {
Ok(io) => {
trace!("TLS Handshake success: {:?}", &host);
Ok(io)
@ -87,7 +87,7 @@ impl<T> Clone for Connector<T> {
}
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<TlsFilter<Base>>;
type Response = Io<Layer<TlsFilter>>;
type Error = ConnectError;
type Service = Connector<T>;
type InitError = ();
@ -100,7 +100,7 @@ impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
}
impl<T: Address> Service<Connect<T>> for Connector<T> {
type Response = Io<TlsFilter<Base>>;
type Response = Io<Layer<TlsFilter>>;
type Error = ConnectError;
type Future<'f> = BoxFuture<'f, Result<Self::Response, Self::Error>>;

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-glommio"
version = "0.2.0"
version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "glommio intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"]
@ -16,8 +16,8 @@ name = "ntex_glommio"
path = "src/lib.rs"
[dependencies]
ntex-bytes = "0.1.18"
ntex-io = "0.2.0"
ntex-bytes = "0.1.19"
ntex-io = "0.2.1"
ntex-util = "0.2.0"
async-oneshot = "0.5.0"
futures-lite = "1.12"

View file

@ -57,17 +57,10 @@ impl Future for ReadTask {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut();
loop {
this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
@ -75,56 +68,35 @@ impl Future for ReadTask {
buf.reserve(hw - remaining);
}
match poll_read_buf(
return match poll_read_buf(
Pin::new(&mut *this.io.0.borrow_mut()),
cx,
&mut buf,
buf,
) {
Poll::Pending => {
pending = true;
break;
}
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("glommio stream is disconnected");
close = true;
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
new_bytes += n;
if new_bytes <= hw {
continue;
}
Poll::Pending
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
Poll::Ready(Err(err))
}
}
};
}
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
Poll::Ready(Ok(()))
}
}
}
})
}
}
@ -372,10 +344,6 @@ pub fn poll_read_buf<T: AsyncRead>(
cx: &mut Context<'_>,
buf: &mut BytesVec,
) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) };
let n = ready!(io.poll_read(cx, dst))?;
@ -407,17 +375,10 @@ impl Future for UnixReadTask {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut();
loop {
this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
@ -425,56 +386,35 @@ impl Future for UnixReadTask {
buf.reserve(hw - remaining);
}
match poll_read_buf(
return match poll_read_buf(
Pin::new(&mut *this.io.0.borrow_mut()),
cx,
&mut buf,
buf,
) {
Poll::Pending => {
pending = true;
break;
}
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("glommio stream is disconnected");
close = true;
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
new_bytes += n;
if new_bytes <= hw {
continue;
}
Poll::Pending
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
Poll::Ready(Err(err))
}
}
};
}
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
Poll::Ready(Ok(()))
}
}
}
})
}
}

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.1] - 2023-01-23
* Refactor Io and Filter types
## [0.2.0] - 2023-01-04
* Release

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.2.0"
version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -17,13 +17,14 @@ path = "src/lib.rs"
[dependencies]
ntex-codec = "0.6.2"
ntex-bytes = "0.1.18"
ntex-bytes = "0.1.19"
ntex-util = "0.2.0"
ntex-service = "1.0.0"
bitflags = "1.3"
log = "0.4"
pin-project-lite = "0.2"
smallvec = "1"
[dev-dependencies]
rand = "0.8"

361
ntex-io/src/buf.rs Normal file
View file

@ -0,0 +1,361 @@
use ntex_bytes::{BytesVec, PoolRef};
use smallvec::SmallVec;
use crate::IoRef;
#[derive(Debug)]
pub struct Stack {
pub(crate) buffers: SmallVec<[(Option<BytesVec>, Option<BytesVec>); 4]>,
}
impl Stack {
pub(crate) fn new() -> Self {
let mut buffers = SmallVec::with_capacity(4);
buffers.push((None, None));
Self { buffers }
}
pub(crate) fn add_layer(&mut self) {
self.buffers.insert(0, (None, None));
}
pub(crate) fn read_buf<F, R>(
&mut self,
io: &IoRef,
idx: usize,
nbytes: usize,
f: F,
) -> R
where
F: FnOnce(&mut ReadBuf<'_>) -> R,
{
let pos = idx + 1;
if self.buffers.len() > pos {
let (curr, next) = self.buffers.split_at_mut(pos);
let mut buf = ReadBuf {
io,
nbytes,
curr: &mut curr[idx],
next: &mut next[0],
};
f(&mut buf)
} else {
let mut val1 = (self.buffers[idx].0.take(), None);
let mut val2 = (None, self.buffers[idx].1.take());
let mut buf = ReadBuf {
io,
nbytes,
curr: &mut val1,
next: &mut val2,
};
let result = f(&mut buf);
self.buffers[idx].0 = val1.0;
self.buffers[idx].1 = val2.1;
result
}
}
pub(crate) fn write_buf<F, R>(&mut self, io: &IoRef, idx: usize, f: F) -> R
where
F: FnOnce(&mut WriteBuf<'_>) -> R,
{
let pos = idx + 1;
if self.buffers.len() > pos {
let (curr, next) = self.buffers.split_at_mut(pos);
let mut buf = WriteBuf {
io,
curr: &mut curr[idx],
next: &mut next[0],
};
f(&mut buf)
} else {
let mut val1 = (self.buffers[idx].0.take(), None);
let mut val2 = (None, self.buffers[idx].1.take());
let mut buf = WriteBuf {
io,
curr: &mut val1,
next: &mut val2,
};
let result = f(&mut buf);
self.buffers[idx].0 = val1.0;
self.buffers[idx].1 = val2.1;
result
}
}
pub(crate) fn first_read_buf_size(&self) -> usize {
self.buffers[0].0.as_ref().map(|b| b.len()).unwrap_or(0)
}
pub(crate) fn first_read_buf(&mut self) -> &mut Option<BytesVec> {
&mut self.buffers[0].0
}
pub(crate) fn first_write_buf(&mut self, io: &IoRef) -> &mut BytesVec {
if self.buffers[0].1.is_none() {
self.buffers[0].1 = Some(io.memory_pool().get_write_buf());
}
self.buffers[0].1.as_mut().unwrap()
}
pub(crate) fn last_read_buf(&mut self) -> &mut Option<BytesVec> {
let idx = self.buffers.len() - 1;
&mut self.buffers[idx].0
}
pub(crate) fn last_write_buf(&mut self) -> &mut Option<BytesVec> {
let idx = self.buffers.len() - 1;
&mut self.buffers[idx].1
}
pub(crate) fn last_write_buf_size(&self) -> usize {
let idx = self.buffers.len() - 1;
self.buffers[idx].1.as_ref().map(|b| b.len()).unwrap_or(0)
}
pub(crate) fn set_last_write_buf(&mut self, buf: BytesVec) {
let idx = self.buffers.len() - 1;
self.buffers[idx].1 = Some(buf);
}
pub(crate) fn release(&mut self, pool: PoolRef) {
for buf in &mut self.buffers {
if let Some(buf) = buf.0.take() {
pool.release_read_buf(buf);
}
if let Some(buf) = buf.1.take() {
pool.release_write_buf(buf);
}
}
}
pub(crate) fn set_memory_pool(&mut self, pool: PoolRef) {
for buf in &mut self.buffers {
if let Some(ref mut b) = buf.0 {
pool.move_vec_in(b);
}
if let Some(ref mut b) = buf.1 {
pool.move_vec_in(b);
}
}
}
}
#[derive(Debug)]
pub struct ReadBuf<'a> {
pub(crate) io: &'a IoRef,
pub(crate) curr: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) next: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) nbytes: usize,
}
impl<'a> ReadBuf<'a> {
#[inline]
/// Get number of newly added bytes
pub fn nbytes(&self) -> usize {
self.nbytes
}
#[inline]
/// Initiate graceful io stream shutdown
pub fn want_shutdown(&self) {
self.io.want_shutdown()
}
#[inline]
/// Get reference to source read buffer
pub fn get_src(&mut self) -> &mut BytesVec {
if self.next.0.is_none() {
self.next.0 = Some(self.io.memory_pool().get_read_buf());
}
self.next.0.as_mut().unwrap()
}
#[inline]
/// Take source read buffer
pub fn take_src(&mut self) -> Option<BytesVec> {
self.next
.0
.take()
.and_then(|b| if b.is_empty() { None } else { Some(b) })
}
#[inline]
/// Set source read buffer
pub fn set_src(&mut self, src: Option<BytesVec>) {
if let Some(src) = src {
if src.is_empty() {
self.io.memory_pool().release_read_buf(src);
} else {
if let Some(b) = self.next.0.take() {
self.io.memory_pool().release_read_buf(b);
}
self.next.0 = Some(src);
}
}
}
#[inline]
/// Get reference to destination read buffer
pub fn get_dst(&mut self) -> &mut BytesVec {
if self.curr.0.is_none() {
self.curr.0 = Some(self.io.memory_pool().get_read_buf());
}
self.curr.0.as_mut().unwrap()
}
#[inline]
/// Take destination read buffer
pub fn take_dst(&mut self) -> BytesVec {
self.curr
.0
.take()
.unwrap_or_else(|| self.io.memory_pool().get_read_buf())
}
#[inline]
/// Set destination read buffer
pub fn set_dst(&mut self, dst: Option<BytesVec>) {
if let Some(dst) = dst {
if dst.is_empty() {
self.io.memory_pool().release_read_buf(dst);
} else {
if let Some(b) = self.curr.0.take() {
self.io.memory_pool().release_read_buf(b);
}
self.curr.0 = Some(dst);
}
}
}
#[inline]
/// Get reference to source and destination read buffers (src, dst)
pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) {
if self.next.0.is_none() {
self.next.0 = Some(self.io.memory_pool().get_read_buf());
}
if self.curr.0.is_none() {
self.curr.0 = Some(self.io.memory_pool().get_read_buf());
}
(self.next.0.as_mut().unwrap(), self.curr.0.as_mut().unwrap())
}
#[inline]
pub fn with_write_buf<'b, F, R>(&'b mut self, f: F) -> R
where
F: FnOnce(&mut WriteBuf<'b>) -> R,
{
let mut buf = WriteBuf {
io: self.io,
curr: self.curr,
next: self.next,
};
f(&mut buf)
}
}
#[derive(Debug)]
pub struct WriteBuf<'a> {
pub(crate) io: &'a IoRef,
pub(crate) curr: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) next: &'a mut (Option<BytesVec>, Option<BytesVec>),
}
impl<'a> WriteBuf<'a> {
#[inline]
/// Initiate graceful io stream shutdown
pub fn want_shutdown(&self) {
self.io.want_shutdown()
}
#[inline]
/// Get reference to source write buffer
pub fn get_src(&mut self) -> &mut BytesVec {
if self.curr.1.is_none() {
self.curr.1 = Some(self.io.memory_pool().get_write_buf());
}
self.curr.1.as_mut().unwrap()
}
#[inline]
/// Take source write buffer
pub fn take_src(&mut self) -> Option<BytesVec> {
self.curr
.1
.take()
.and_then(|b| if b.is_empty() { None } else { Some(b) })
}
#[inline]
/// Set source write buffer
pub fn set_src(&mut self, src: Option<BytesVec>) {
if let Some(b) = self.curr.1.take() {
self.io.memory_pool().release_read_buf(b);
}
self.curr.1 = src;
}
#[inline]
/// Get reference to destination write buffer
pub fn get_dst(&mut self) -> &mut BytesVec {
if self.next.1.is_none() {
self.next.1 = Some(self.io.memory_pool().get_write_buf());
}
self.next.1.as_mut().unwrap()
}
#[inline]
/// Take destination write buffer
pub fn take_dst(&mut self) -> BytesVec {
self.next
.1
.take()
.unwrap_or_else(|| self.io.memory_pool().get_write_buf())
}
#[inline]
/// Set destination write buffer
pub fn set_dst(&mut self, dst: Option<BytesVec>) {
if let Some(dst) = dst {
if dst.is_empty() {
self.io.memory_pool().release_write_buf(dst);
} else {
if let Some(b) = self.next.1.take() {
self.io.memory_pool().release_write_buf(b);
}
self.next.1 = Some(dst);
}
}
}
#[inline]
/// Get reference to source and destination buffers (src, dst)
pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) {
if self.curr.1.is_none() {
self.curr.1 = Some(self.io.memory_pool().get_write_buf());
}
if self.next.1.is_none() {
self.next.1 = Some(self.io.memory_pool().get_write_buf());
}
(self.curr.1.as_mut().unwrap(), self.next.1.as_mut().unwrap())
}
#[inline]
pub fn with_read_buf<'b, F, R>(&'b mut self, f: F) -> R
where
F: FnOnce(&mut ReadBuf<'b>) -> R,
{
let mut buf = ReadBuf {
io: self.io,
curr: self.curr,
next: self.next,
nbytes: 0,
};
f(&mut buf)
}
}

View file

@ -316,18 +316,15 @@ where
}
// shutdown service
DispatcherState::Shutdown => {
let err = slf.error.take();
return if this.inner.shared.service.poll_shutdown(cx).is_ready() {
log::trace!("service shutdown is completed, stop");
Poll::Ready(if let Some(err) = err {
Poll::Ready(if let Some(err) = slf.error.take() {
Err(err)
} else {
Ok(())
})
} else {
slf.error.set(err);
Poll::Pending
};
}
@ -632,9 +629,7 @@ mod tests {
// close read side
client.close().await;
// TODO! fix
// assert!(client.is_server_dropped());
assert!(client.is_server_dropped());
// service must be checked for readiness only once
assert_eq!(counter.get(), 1);

View file

@ -1,8 +1,6 @@
use std::{any, io, task::Context, task::Poll};
use ntex_bytes::BytesVec;
use super::{io::Flags, Filter, IoRef, ReadStatus, WriteStatus};
use super::{buf::Stack, io::Flags, FilterLayer, IoRef, ReadStatus, WriteStatus};
/// Default `Io` filter
pub struct Base(IoRef);
@ -13,8 +11,54 @@ impl Base {
}
}
pub struct Layer<F, L = Base>(pub(crate) F, L);
impl<F: FilterLayer, L: Filter> Layer<F, L> {
pub(crate) fn new(f: F, l: L) -> Self {
Self(f, l)
}
}
pub(crate) struct NullFilter;
const NULL: NullFilter = NullFilter;
impl NullFilter {
pub(super) fn get() -> &'static dyn Filter {
&NULL
}
}
pub trait Filter: 'static {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>>;
fn process_read_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
nbytes: usize,
) -> io::Result<usize>;
/// Process write buffer
fn process_write_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
) -> io::Result<()>;
/// Gracefully shutdown filter
fn shutdown(&self, io: &IoRef, stack: &mut Stack, idx: usize) -> io::Result<Poll<()>>;
/// Check readiness for read operations
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
/// Check readiness for write operations
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
}
impl Filter for Base {
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
if let Some(hnd) = self.0 .0.handle.take() {
let res = hnd.query(id);
@ -25,11 +69,6 @@ impl Filter for Base {
}
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
let flags = self.0.flags();
@ -67,51 +106,128 @@ impl Filter for Base {
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.0 .0.read_buf.take()
fn process_read_buf(
&self,
_: &IoRef,
_: &mut Stack,
_: usize,
nbytes: usize,
) -> io::Result<usize> {
Ok(nbytes)
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
self.0 .0.write_buf.take()
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.0 .0.read_buf.set(Some(buf));
}
#[inline]
fn process_read_buf(&self, _: &IoRef, n: usize) -> io::Result<(usize, usize)> {
let buf = self.0 .0.read_buf.as_ptr();
let ref_buf = unsafe { buf.as_ref().unwrap() };
let total = ref_buf.as_ref().map(|v| v.len()).unwrap_or(0);
Ok((total, n))
}
#[inline]
fn release_write_buf(&self, buf: BytesVec) -> Result<(), io::Error> {
let pool = self.0.memory_pool();
if buf.is_empty() {
pool.release_write_buf(buf);
} else {
if buf.len() >= pool.write_params_high() {
fn process_write_buf(&self, _: &IoRef, s: &mut Stack, _: usize) -> io::Result<()> {
if let Some(buf) = s.last_write_buf() {
if buf.len() >= self.0.memory_pool().write_params_high() {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
}
self.0 .0.write_buf.set(Some(buf));
self.0 .0.write_task.wake();
}
Ok(())
}
#[inline]
fn shutdown(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result<Poll<()>> {
Ok(Poll::Ready(()))
}
}
pub(crate) struct NullFilter;
impl<F, L> Filter for Layer<F, L>
where
F: FilterLayer,
L: Filter,
{
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
self.0.query(id).or_else(|| self.1.query(id))
}
const NULL: NullFilter = NullFilter;
#[inline]
fn shutdown(&self, io: &IoRef, stack: &mut Stack, idx: usize) -> io::Result<Poll<()>> {
let result1 = stack.write_buf(io, idx, |buf| self.0.shutdown(buf))?;
self.process_write_buf(io, stack, idx)?;
impl NullFilter {
pub(super) fn get() -> &'static dyn Filter {
&NULL
let result2 = if F::BUFFERS {
self.1.shutdown(io, stack, idx + 1)?
} else {
self.1.shutdown(io, stack, idx)?
};
if result1.is_pending() || result2.is_pending() {
Ok(Poll::Pending)
} else {
Ok(Poll::Ready(()))
}
}
#[inline]
fn process_read_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
nbytes: usize,
) -> io::Result<usize> {
let nbytes = if F::BUFFERS {
self.1.process_read_buf(io, stack, idx + 1, nbytes)?
} else {
self.1.process_read_buf(io, stack, idx, nbytes)?
};
stack.read_buf(io, idx, nbytes, |buf| self.0.process_read_buf(buf))
}
#[inline]
fn process_write_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
) -> io::Result<()> {
stack.write_buf(io, idx, |buf| self.0.process_write_buf(buf))?;
if F::BUFFERS {
self.1.process_write_buf(io, stack, idx + 1)
} else {
self.1.process_write_buf(io, stack, idx)
}
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
let res1 = self.0.poll_read_ready(cx);
let res2 = self.1.poll_read_ready(cx);
match res1 {
Poll::Pending => Poll::Pending,
Poll::Ready(ReadStatus::Ready) => res2,
Poll::Ready(ReadStatus::Terminate) => Poll::Ready(ReadStatus::Terminate),
}
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let res1 = self.0.poll_write_ready(cx);
let res2 = self.1.poll_write_ready(cx);
match res1 {
Poll::Pending => Poll::Pending,
Poll::Ready(WriteStatus::Ready) => res2,
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => {
if res2 == Poll::Ready(WriteStatus::Terminate) {
Poll::Ready(WriteStatus::Terminate)
} else {
Poll::Ready(WriteStatus::Shutdown(t))
}
}
Poll::Ready(WriteStatus::Timeout(t)) => match res2 {
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => {
Poll::Ready(WriteStatus::Shutdown(t))
}
_ => Poll::Ready(WriteStatus::Timeout(t)),
},
}
}
}
@ -121,11 +237,6 @@ impl Filter for NullFilter {
None
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline]
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<ReadStatus> {
Poll::Ready(ReadStatus::Terminate)
@ -137,25 +248,23 @@ impl Filter for NullFilter {
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
None
fn process_read_buf(
&self,
_: &IoRef,
_: &mut Stack,
_: usize,
_: usize,
) -> io::Result<usize> {
Ok(0)
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
None
}
#[inline]
fn release_read_buf(&self, _: BytesVec) {}
#[inline]
fn process_read_buf(&self, _: &IoRef, _: usize) -> io::Result<(usize, usize)> {
Ok((0, 0))
}
#[inline]
fn release_write_buf(&self, _: BytesVec) -> Result<(), io::Error> {
fn process_write_buf(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result<()> {
Ok(())
}
#[inline]
fn shutdown(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result<Poll<()>> {
Ok(Poll::Ready(()))
}
}

View file

@ -1,4 +1,4 @@
use std::cell::Cell;
use std::cell::{Cell, RefCell};
use std::task::{Context, Poll};
use std::{fmt, future::Future, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc, time};
@ -7,10 +7,11 @@ use ntex_codec::{Decoder, Encoder};
use ntex_util::time::{now, Millis};
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
use super::filter::{Base, NullFilter};
use super::seal::Sealed;
use super::tasks::{ReadContext, WriteContext};
use super::{Filter, FilterFactory, Handle, IoStatusUpdate, IoStream, RecvError};
use crate::buf::Stack;
use crate::filter::{Base, Filter, Layer, NullFilter};
use crate::seal::Sealed;
use crate::tasks::{ReadContext, WriteContext};
use crate::{FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
bitflags::bitflags! {
pub struct Flags: u16 {
@ -59,8 +60,7 @@ pub(crate) struct IoState {
pub(super) read_task: LocalWaker,
pub(super) write_task: LocalWaker,
pub(super) dispatch_task: LocalWaker,
pub(super) read_buf: Cell<Option<BytesVec>>,
pub(super) write_buf: Cell<Option<BytesVec>>,
pub(super) buffer: RefCell<Stack>,
pub(super) filter: Cell<&'static dyn Filter>,
pub(super) handle: Cell<Option<Box<dyn Handle>>>,
#[allow(clippy::box_collection)]
@ -104,7 +104,6 @@ impl IoState {
}
}
#[inline]
pub(super) fn io_stopped(&self, err: Option<io::Error>) {
if err.is_some() {
self.error.set(err);
@ -119,9 +118,8 @@ impl IoState {
);
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, err: Option<io::Error>) {
pub(super) fn init_shutdown(&self, err: Option<io::Error>, io: &IoRef) {
if err.is_some() {
self.io_stopped(err);
} else if !self
@ -131,28 +129,25 @@ impl IoState {
{
log::trace!("initiate io shutdown {:?}", self.flags.get());
self.insert_flags(Flags::IO_STOPPING_FILTERS);
self.shutdown_filters();
self.shutdown_filters(io);
}
}
#[inline]
pub(super) fn shutdown_filters(&self) {
pub(super) fn shutdown_filters(&self, io: &IoRef) {
if !self
.flags
.get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
{
match self.filter.get().poll_shutdown() {
Poll::Ready(Ok(())) => {
let mut buffer = self.buffer.borrow_mut();
match self.filter.get().shutdown(io, &mut buffer, 0) {
Ok(Poll::Ready(())) => {
self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
self.insert_flags(Flags::IO_STOPPING);
}
Poll::Ready(Err(err)) => {
self.io_stopped(Some(err));
}
Poll::Pending => {
Ok(Poll::Pending) => {
let flags = self.flags.get();
// check read buffer, if buffer is not consumed it is unlikely
// that filter will properly complete shutdown
@ -165,40 +160,37 @@ impl IoState {
self.insert_flags(Flags::IO_STOPPING);
}
}
}
Err(err) => {
self.io_stopped(Some(err));
}
};
self.write_task.wake();
}
}
#[inline]
pub(super) fn with_read_buf<Fn, Ret>(&self, release: bool, f: Fn) -> Ret
where
Fn: FnOnce(&mut Option<BytesVec>) -> Ret,
{
let filter = self.filter.get();
let mut buf = filter.get_read_buf();
let result = f(&mut buf);
// use top most buffer
let mut buffer = self.buffer.borrow_mut();
let buf = buffer.first_read_buf();
let result = f(buf);
if let Some(buf) = buf {
if release {
// release buffer
if buf.is_empty() {
self.pool.get().release_read_buf(buf);
return result;
}
// release buffer
if release && buf.as_ref().map(|b| b.is_empty()).unwrap_or(false) {
if let Some(b) = buf.take() {
self.pool.get().release_read_buf(b);
}
filter.release_read_buf(buf);
}
result
}
#[inline]
pub(super) fn with_write_buf<Fn, Ret>(&self, f: Fn) -> Ret
where
Fn: FnOnce(&mut Option<BytesVec>) -> Ret,
{
let buf = self.write_buf.as_ptr();
let ref_buf = unsafe { buf.as_mut().unwrap() };
f(ref_buf)
f(self.buffer.borrow_mut().last_write_buf())
}
}
@ -221,12 +213,7 @@ impl hash::Hash for IoState {
impl Drop for IoState {
#[inline]
fn drop(&mut self) {
if let Some(buf) = self.read_buf.take() {
self.pool.get().release_read_buf(buf);
}
if let Some(buf) = self.write_buf.take() {
self.pool.get().release_write_buf(buf);
}
self.buffer.borrow_mut().release(self.pool.get());
}
}
@ -248,8 +235,7 @@ impl Io {
dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(),
write_task: LocalWaker::new(),
read_buf: Cell::new(None),
write_buf: Cell::new(None),
buffer: RefCell::new(Stack::new()),
filter: Cell::new(NullFilter::get()),
handle: Cell::new(None),
on_disconnect: Cell::new(None),
@ -277,14 +263,7 @@ impl<F> Io<F> {
#[inline]
/// Set memory pool
pub fn set_memory_pool(&self, pool: PoolRef) {
if let Some(mut buf) = self.0 .0.read_buf.take() {
pool.move_vec_in(&mut buf);
self.0 .0.read_buf.set(Some(buf));
}
if let Some(mut buf) = self.0 .0.write_buf.take() {
pool.move_vec_in(&mut buf);
self.0 .0.write_buf.set(Some(buf));
}
self.0 .0.buffer.borrow_mut().set_memory_pool(pool);
self.0 .0.pool.set(pool);
}
@ -312,8 +291,7 @@ impl<F> Io<F> {
dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(),
write_task: LocalWaker::new(),
read_buf: Cell::new(None),
write_buf: Cell::new(None),
buffer: RefCell::new(Stack::new()),
filter: Cell::new(NullFilter::get()),
handle: Cell::new(None),
on_disconnect: Cell::new(None),
@ -353,57 +331,37 @@ impl<F> Io<F> {
}
impl<F: Filter> Io<F> {
#[inline]
/// Get referece to a filter
pub fn filter(&self) -> &F {
self.1.filter()
}
#[inline]
/// Convert current io stream into sealed version
pub fn seal(mut self) -> Io<Sealed> {
// get current filter
let filter = unsafe {
let filter = self.1.seal();
let filter_ref: &'static dyn Filter = {
let filter: &dyn Filter = filter.0.as_ref();
std::mem::transmute(filter)
};
self.0 .0.filter.replace(filter_ref);
filter
};
Io(self.0.clone(), FilterItem::with_sealed(filter))
}
#[inline]
/// Create new filter and replace current one
pub fn add_filter<T>(self, factory: T) -> T::Future
where
T: FilterFactory<F>,
{
factory.create(self)
let (filter, filter_ref) = self.1.seal();
self.0 .0.filter.replace(filter_ref);
Io(self.0.clone(), filter)
}
#[inline]
/// Map current filter with new one
pub fn map_filter<T, U, E>(mut self, map: U) -> Result<Io<T>, E>
pub fn add_filter<U>(mut self, nf: U) -> Io<Layer<U, F>>
where
T: Filter,
U: FnOnce(F) -> Result<T, E>,
U: FilterLayer,
{
// replace current filter
let filter = unsafe {
let filter = Box::new(map(*(self.1.get_filter()))?);
let filter_ref: &'static dyn Filter = {
let filter: &dyn Filter = filter.as_ref();
std::mem::transmute(filter)
};
self.0 .0.filter.replace(filter_ref);
filter
};
// add layer to buffers
if U::BUFFERS {
self.0 .0.buffer.borrow_mut().add_layer();
}
Ok(Io(self.0.clone(), FilterItem::with_filter(filter)))
// replace current filter
let (filter, filter_ref) = self.1.add_filter(nf);
self.0 .0.filter.replace(filter_ref);
Io(self.0.clone(), filter)
}
}
impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
#[inline]
/// Get referece to a filter
pub fn filter(&self) -> &F {
&self.1.filter().0
}
}
@ -629,8 +587,10 @@ impl<F> Io<F> {
}
} else {
if !flags.contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.init_shutdown(None);
self.0 .0.init_shutdown(None, &self.0);
}
self.0 .0.read_task.wake();
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
}
@ -759,20 +719,6 @@ impl<F> FilterItem<F> {
slf
}
fn with_sealed(f: Sealed) -> Self {
let mut slf = Self {
data: [0; SEALED_SIZE],
_t: marker::PhantomData,
};
unsafe {
let ptr = &mut slf.data as *mut _ as *mut Sealed;
ptr.write(f);
slf.data[KIND_IDX] |= KIND_SEALED;
}
slf
}
/// Get filter, panic if it is not filter
fn filter(&self) -> &F {
if self.data[KIND_IDX] & KIND_PTR != 0 {
@ -786,8 +732,8 @@ impl<F> FilterItem<F> {
}
}
/// Get filter, panic if it is not filter
fn get_filter(&mut self) -> Box<F> {
/// Get filter, panic if it is not set
fn take_filter(&mut self) -> Box<F> {
if self.data[KIND_IDX] & KIND_PTR != 0 {
self.data[KIND_IDX] &= KIND_UNMASK;
let ptr = &mut self.data as *mut _ as *mut *mut F;
@ -801,7 +747,7 @@ impl<F> FilterItem<F> {
}
/// Get sealed, panic if it is already sealed
fn get_sealed(&mut self) -> Sealed {
fn take_sealed(&mut self) -> Sealed {
if self.data[KIND_IDX] & KIND_SEALED != 0 {
self.data[KIND_IDX] &= KIND_UNMASK;
let ptr = &mut self.data as *mut _ as *mut Sealed;
@ -820,25 +766,54 @@ impl<F> FilterItem<F> {
fn drop_filter(&mut self) {
if self.data[KIND_IDX] & KIND_PTR != 0 {
self.get_filter();
self.take_filter();
} else if self.data[KIND_IDX] & KIND_SEALED != 0 {
self.get_sealed();
self.take_sealed();
}
}
}
impl<F: Filter> FilterItem<F> {
fn seal(&mut self) -> Sealed {
if self.data[KIND_IDX] & KIND_PTR != 0 {
Sealed(Box::new(*self.get_filter()))
fn add_filter<T: FilterLayer>(
&mut self,
new: T,
) -> (FilterItem<Layer<T, F>>, &'static dyn Filter) {
let filter = Box::new(Layer::new(new, *self.take_filter()));
let filter_ref: &'static dyn Filter = {
let filter: &dyn Filter = filter.as_ref();
unsafe { std::mem::transmute(filter) }
};
(FilterItem::with_filter(filter), filter_ref)
}
fn seal(&mut self) -> (FilterItem<Sealed>, &'static dyn Filter) {
let filter = if self.data[KIND_IDX] & KIND_PTR != 0 {
Sealed(Box::new(*self.take_filter()))
} else if self.data[KIND_IDX] & KIND_SEALED != 0 {
self.get_sealed()
self.take_sealed()
} else {
panic!(
"Wrong filter item {:?} expected: {:?}",
self.data[KIND_IDX], KIND_PTR
);
};
let filter_ref: &'static dyn Filter = {
let filter: &dyn Filter = filter.0.as_ref();
unsafe { std::mem::transmute(filter) }
};
let mut slf = FilterItem {
data: [0; SEALED_SIZE],
_t: marker::PhantomData,
};
unsafe {
let ptr = &mut slf.data as *mut _ as *mut Sealed;
ptr.write(filter);
slf.data[KIND_IDX] |= KIND_SEALED;
}
(slf, filter_ref)
}
}

View file

@ -1,9 +1,9 @@
use std::{any, fmt, hash, io, time};
use ntex_bytes::{BufMut, BytesVec, PoolRef};
use ntex_bytes::{BytesVec, PoolRef};
use ntex_codec::{Decoder, Encoder};
use super::{io::Flags, timer, types, Filter, IoRef, OnDisconnect};
use super::{io::Flags, timer, types, Filter, IoRef, OnDisconnect, WriteBuf};
impl IoRef {
#[inline]
@ -49,7 +49,7 @@ impl IoRef {
/// Notify dispatcher and initiate io stream shutdown process.
pub fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP);
self.0.init_shutdown(None);
self.0.init_shutdown(None, self);
}
#[inline]
@ -72,8 +72,16 @@ impl IoRef {
#[inline]
/// Gracefully shutdown io stream
pub fn want_shutdown(&self, err: Option<io::Error>) {
self.0.init_shutdown(err);
pub fn want_shutdown(&self) {
if !self
.0
.flags
.get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
{
log::trace!("initiate io shutdown {:?}", self.0.flags.get());
self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
}
}
#[inline]
@ -96,13 +104,8 @@ impl IoRef {
if !flags.contains(Flags::IO_STOPPING) {
self.with_write_buf(|buf| {
let (hw, lw) = self.memory_pool().write_params().unpack();
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
self.memory_pool().resize_write_buf(buf);
// encode item and wake write task
codec.encode_vec(item, buf)
@ -151,21 +154,35 @@ impl IoRef {
#[inline]
/// Get mut access to write buffer
pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error>
pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
where
F: FnOnce(&mut BytesVec) -> R,
{
let filter = self.0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
let mut buffer = self.0.buffer.borrow_mut();
let is_write_sleep = buffer.last_write_buf_size() == 0;
let result = f(&mut buf);
if is_write_sleep {
let result = f(buffer.first_write_buf(self));
self.0
.filter
.get()
.process_write_buf(self, &mut buffer, 0)?;
if is_write_sleep && buffer.last_write_buf_size() != 0 {
self.0.write_task.wake();
}
filter.release_write_buf(buf)?;
Ok(result)
}
#[inline]
/// Get mut access to write buffer
pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
where
F: FnOnce(&mut WriteBuf<'_>) -> R,
{
let mut b = self.0.buffer.borrow_mut();
let result = b.write_buf(self, 0, f);
self.0.filter.get().process_write_buf(self, &mut b, 0)?;
self.0.write_task.wake();
Ok(result)
}
@ -240,16 +257,15 @@ impl fmt::Debug for IoRef {
#[cfg(test)]
mod tests {
use std::cell::{Cell, RefCell};
use std::{future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
use std::{future::Future, pin::Pin, rc::Rc, task::Poll};
use ntex_bytes::Bytes;
use ntex_codec::BytesCodec;
use ntex_util::future::{lazy, poll_fn, Ready};
use ntex_util::future::{lazy, poll_fn};
use ntex_util::time::{sleep, Millis};
use super::*;
use crate::testing::IoTest;
use crate::{Filter, FilterFactory, Io, ReadStatus, WriteStatus};
use crate::{testing::IoTest, FilterLayer, Io, ReadBuf, WriteBuf};
const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
@ -370,87 +386,28 @@ mod tests {
assert_eq!(waiter.await, ());
}
struct Counter<F> {
struct Counter {
idx: usize,
inner: F,
in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>,
read_order: Rc<RefCell<Vec<usize>>>,
write_order: Rc<RefCell<Vec<usize>>>,
}
impl<F: Filter> Filter for Counter<F> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.poll_read_ready(cx)
}
fn get_read_buf(&self) -> Option<BytesVec> {
self.inner.get_read_buf()
}
impl FilterLayer for Counter {
const BUFFERS: bool = false;
fn release_read_buf(&self, buf: BytesVec) {
self.inner.release_read_buf(buf)
}
fn process_read_buf(&self, io: &IoRef, n: usize) -> io::Result<(usize, usize)> {
let result = self.inner.process_read_buf(io, n)?;
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
self.read_order.borrow_mut().push(self.idx);
self.in_bytes.set(self.in_bytes.get() + result.1);
Ok(result)
self.in_bytes.set(self.in_bytes.get() + buf.nbytes());
Ok(buf.nbytes())
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.poll_write_ready(cx)
}
fn get_write_buf(&self) -> Option<BytesVec> {
if let Some(buf) = self.inner.get_write_buf() {
self.out_bytes.set(self.out_bytes.get() - buf.len());
Some(buf)
} else {
None
}
}
fn release_write_buf(&self, buf: BytesVec) -> Result<(), io::Error> {
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
self.write_order.borrow_mut().push(self.idx);
self.out_bytes.set(self.out_bytes.get() + buf.len());
self.inner.release_write_buf(buf)
}
}
struct CounterFactory(
usize,
Rc<Cell<usize>>,
Rc<Cell<usize>>,
Rc<RefCell<Vec<usize>>>,
Rc<RefCell<Vec<usize>>>,
);
impl<F: Filter> FilterFactory<F> for CounterFactory {
type Filter = Counter<F>;
type Error = ();
type Future = Ready<Io<Counter<F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future {
let idx = self.0;
let in_bytes = self.1.clone();
let out_bytes = self.2.clone();
let read_order = self.3.clone();
let write_order = self.4;
Ready::Ok(
io.map_filter(|inner| {
Ok::<_, ()>(Counter {
idx,
inner,
in_bytes,
out_bytes,
read_order,
write_order,
})
})
.unwrap(),
)
self.out_bytes
.set(self.out_bytes.get() + buf.get_dst().len());
Ok(())
}
}
@ -460,24 +417,22 @@ mod tests {
let out_bytes = Rc::new(Cell::new(0));
let read_order = Rc::new(RefCell::new(Vec::new()));
let write_order = Rc::new(RefCell::new(Vec::new()));
let factory = CounterFactory(
1,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
);
let (client, server) = IoTest::create();
let state = Io::new(server).add_filter(factory).await.unwrap();
let io = Io::new(server).add_filter(Counter {
idx: 1,
in_bytes: in_bytes.clone(),
out_bytes: out_bytes.clone(),
read_order: read_order.clone(),
write_order: write_order.clone(),
});
client.remote_buffer_cap(1024);
client.write(TEXT);
let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN));
state
.send(Bytes::from_static(b"test"), &BytesCodec)
io.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
let buf = client.read().await.unwrap();
@ -496,24 +451,20 @@ mod tests {
let (client, server) = IoTest::create();
let state = Io::new(server)
.add_filter(CounterFactory(
1,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
))
.await
.unwrap()
.add_filter(CounterFactory(
2,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
))
.await
.unwrap();
.add_filter(Counter {
idx: 1,
in_bytes: in_bytes.clone(),
out_bytes: out_bytes.clone(),
read_order: read_order.clone(),
write_order: write_order.clone(),
})
.add_filter(Counter {
idx: 2,
in_bytes: in_bytes.clone(),
out_bytes: out_bytes.clone(),
read_order: read_order.clone(),
write_order: write_order.clone(),
});
let state = state.seal();
client.remote_buffer_cap(1024);

View file

@ -7,6 +7,7 @@ use std::{
pub mod testing;
pub mod types;
mod buf;
mod dispatcher;
mod filter;
mod framed;
@ -17,12 +18,12 @@ mod tasks;
mod timer;
mod utils;
use ntex_bytes::BytesVec;
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Millis;
pub use self::buf::{ReadBuf, WriteBuf};
pub use self::dispatcher::Dispatcher;
pub use self::filter::Base;
pub use self::filter::{Base, Filter, Layer};
pub use self::framed::Framed;
pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::seal::{IoBoxed, Sealed};
@ -49,44 +50,51 @@ pub enum WriteStatus {
Terminate,
}
pub trait Filter: 'static {
fn query(&self, _: TypeId) -> Option<Box<dyn Any>> {
None
#[allow(unused_variables)]
pub trait FilterLayer: 'static {
/// Create buffers for this filter
const BUFFERS: bool = true;
#[inline]
/// Check readiness for read operations
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
Poll::Ready(ReadStatus::Ready)
}
fn get_read_buf(&self) -> Option<BytesVec>;
fn release_read_buf(&self, buf: BytesVec);
#[inline]
/// Check readiness for write operations
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
Poll::Ready(WriteStatus::Ready)
}
/// Process read buffer
///
/// Returns tuple (total bytes, new bytes)
fn process_read_buf(&self, io: &IoRef, n: usize) -> sio::Result<(usize, usize)>;
/// Inner filter must process buffer before current.
/// Returns number of new bytes.
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> sio::Result<usize>;
fn get_write_buf(&self) -> Option<BytesVec>;
/// Process write buffer
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> sio::Result<()>;
fn release_write_buf(&self, buf: BytesVec) -> sio::Result<()>;
/// Check readiness for read operations
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
/// Check readiness for write operations
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
/// Query internal filter data
fn query(&self, id: TypeId) -> Option<Box<dyn Any>> {
None
}
/// Gracefully shutdown filter
fn poll_shutdown(&self) -> Poll<sio::Result<()>> {
Poll::Ready(Ok(()))
fn shutdown(&self, buf: &mut WriteBuf<'_>) -> sio::Result<Poll<()>> {
Ok(Poll::Ready(()))
}
}
/// Creates new `Filter` values.
pub trait FilterFactory<F: Filter>: Sized {
pub trait FilterFactory<F>: Sized {
/// The `Filter` value created by this factory
type Filter: Filter;
type Filter: FilterLayer;
/// Errors produced while building a filter.
type Error: fmt::Debug;
/// The future of the `FilterFactory` instance.
type Future: Future<Output = Result<Io<Self::Filter>, Self::Error>>;
type Future: Future<Output = Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
/// Create and return a new filter value asynchronously.
fn create(self, st: Io<F>) -> Self::Future;

View file

@ -1,6 +1,6 @@
use std::ops;
use crate::{Filter, Io};
use crate::{filter::Filter, Io};
/// Sealed filter type
pub struct Sealed(pub(crate) Box<dyn Filter>);

View file

@ -12,62 +12,93 @@ impl ReadContext {
Self(io.clone())
}
#[inline]
/// Return memory pool for this context
pub fn memory_pool(&self) -> PoolRef {
self.0.memory_pool()
}
#[inline]
/// Check readiness for read operations
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.0.filter().poll_read_ready(cx)
}
#[inline]
/// Get read buffer
pub fn get_read_buf(&self) -> BytesVec {
self.0
.0
.read_buf
pub fn with_buf<F>(&self, f: F) -> Poll<()>
where
F: FnOnce(&mut BytesVec, usize, usize) -> Poll<io::Result<()>>,
{
let mut stack = self.0 .0.buffer.borrow_mut();
let mut buf = stack
.last_read_buf()
.take()
.unwrap_or_else(|| self.0.memory_pool().get_read_buf())
}
.unwrap_or_else(|| self.0.memory_pool().get_read_buf());
#[inline]
/// Release read buffer after io read operations
pub fn release_read_buf(&self, buf: BytesVec, nbytes: usize) {
let total = buf.len();
let (hw, lw) = self.0.memory_pool().read_params().unpack();
// call provided callback
let result = f(&mut buf, hw, lw);
// handle buffer changes
if buf.is_empty() {
self.0.memory_pool().release_read_buf(buf);
} else {
self.0 .0.read_buf.set(Some(buf));
let filter = self.0.filter();
match filter.process_read_buf(&self.0, nbytes) {
Ok((total, nbytes)) => {
if nbytes > 0 {
if total > self.0.memory_pool().read_params().high as usize {
let total2 = buf.len();
let nbytes = if total2 > total { total2 - total } else { 0 };
*stack.last_read_buf() = Some(buf);
if nbytes > 0 {
let buf_full = nbytes >= hw;
match self
.0
.filter()
.process_read_buf(&self.0, &mut stack, 0, nbytes)
{
Ok(nbytes) => {
if nbytes > 0 {
if buf_full || stack.first_read_buf_size() >= hw {
log::trace!(
"io read buffer is too large {}, enable read back-pressure",
total2
);
self.0
.0
.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else {
self.0 .0.insert_flags(Flags::RD_READY);
}
self.0 .0.dispatch_task.wake();
log::trace!(
"buffer is too large {}, enable read back-pressure",
total
"new {} bytes available, wakeup dispatcher",
nbytes,
);
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else if buf_full {
// read task is paused because of read back-pressure
self.0 .0.read_task.wake();
}
}
Err(err) => {
self.0 .0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::RD_READY);
log::trace!("new {} bytes available, wakeup dispatcher", nbytes);
self.0 .0.init_shutdown(Some(err), &self.0);
}
}
Err(err) => {
self.0 .0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::RD_READY);
self.0.want_shutdown(Some(err));
}
}
}
let result = match result {
Poll::Ready(Ok(())) => {
self.0 .0.io_stopped(None);
Poll::Ready(())
}
Poll::Ready(Err(e)) => {
self.0 .0.io_stopped(Some(e));
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
};
drop(stack);
if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters();
self.0 .0.shutdown_filters(&self.0);
}
result
}
#[inline]
@ -100,7 +131,7 @@ impl WriteContext {
#[inline]
/// Get write buffer
pub fn get_write_buf(&self) -> Option<BytesVec> {
self.0 .0.write_buf.take()
self.0 .0.buffer.borrow_mut().last_write_buf().take()
}
#[inline]
@ -125,11 +156,11 @@ impl WriteContext {
self.0.set_flags(flags);
self.0 .0.dispatch_task.wake();
}
self.0 .0.write_buf.set(Some(buf))
self.0 .0.buffer.borrow_mut().set_last_write_buf(buf);
}
if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters();
self.0 .0.shutdown_filters(&self.0);
}
Ok(())

View file

@ -344,6 +344,12 @@ impl Drop for IoTest {
_ => (),
}
self.state.set(state);
let guard = self.remote.lock().unwrap();
let mut remote = guard.borrow_mut();
remote.read = IoTestState::Close;
remote.waker.wake();
log::trace!("drop remote socket");
}
}
@ -388,58 +394,58 @@ impl Future for ReadTask {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(())
}
Poll::Ready(ReadStatus::Ready) => {
let io = &this.io;
let pool = this.state.memory_pool();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
this.state.with_buf(|buf, hw, lw| {
match this.state.poll_ready(cx) {
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(Ok(()))
}
Poll::Ready(ReadStatus::Ready) => {
let io = &this.io;
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match io.poll_read_buf(cx, &mut buf) {
Poll::Pending => {
log::trace!("no more data in io stream, read: {:?}", new_bytes);
break;
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
this.state.release_read_buf(buf, new_bytes);
this.state.close(None);
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
match io.poll_read_buf(cx, buf) {
Poll::Pending => {
log::trace!(
"no more data in io stream, read: {:?}",
new_bytes
);
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
return Poll::Ready(Ok(()));
} else {
new_bytes += n;
if buf.len() >= hw {
log::trace!(
"high water mark pause reading, read: {:?}",
new_bytes
);
break;
}
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
return Poll::Ready(Err(err));
}
}
}
}
this.state.release_read_buf(buf, new_bytes);
Poll::Pending
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
Poll::Pending => Poll::Pending,
}
})
}
}

View file

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use ntex_service::{fn_service, pipeline_factory, Service, ServiceFactory};
use ntex_util::future::Ready;
use crate::{Filter, FilterFactory, Io, IoBoxed};
use crate::{Filter, FilterFactory, Io, IoBoxed, Layer};
/// Service that converts any Io<F> stream to IoBoxed stream
pub fn seal<F, S, C>(
@ -30,7 +30,6 @@ where
pub fn filter<T, F>(filter: T) -> FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
F: Filter,
{
FilterServiceFactory {
filter,
@ -46,9 +45,8 @@ pub struct FilterServiceFactory<T, F> {
impl<T, F> ServiceFactory<Io<F>> for FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
F: Filter,
{
type Response = Io<T::Filter>;
type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error;
type Service = FilterService<T, F>;
type InitError = ();
@ -71,25 +69,28 @@ pub struct FilterService<T, F> {
impl<T, F> Service<Io<F>> for FilterService<T, F>
where
T: FilterFactory<F> + Clone,
F: Filter,
{
type Response = Io<T::Filter>;
type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error;
type Future<'f> = T::Future where T: 'f;
type Future<'f> = T::Future where T: 'f, F: 'f;
#[inline]
fn call(&self, req: Io<F>) -> Self::Future<'_> {
req.add_filter(self.filter.clone())
self.filter.clone().create(req)
}
}
#[cfg(test)]
mod tests {
use ntex_bytes::{Bytes, BytesVec};
use std::io;
use ntex_bytes::Bytes;
use ntex_codec::BytesCodec;
use super::*;
use crate::{filter::NullFilter, testing::IoTest};
use crate::{
buf::Stack, filter::NullFilter, testing::IoTest, FilterLayer, ReadBuf, WriteBuf,
};
#[ntex::test]
async fn test_utils() {
@ -114,16 +115,28 @@ mod tests {
assert_eq!(buf, b"RES".as_ref());
}
#[derive(Copy, Clone, Debug)]
struct NullFilterFactory;
pub(crate) struct TestFilter;
impl<F: Filter> FilterFactory<F> for NullFilterFactory {
type Filter = crate::filter::NullFilter;
impl FilterLayer for TestFilter {
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
Ok(buf.nbytes())
}
fn process_write_buf(&self, _: &mut WriteBuf<'_>) -> io::Result<()> {
Ok(())
}
}
#[derive(Copy, Clone, Debug)]
struct TestFilterFactory;
impl<F: Filter> FilterFactory<F> for TestFilterFactory {
type Filter = TestFilter;
type Error = std::convert::Infallible;
type Future = Ready<Io<Self::Filter>, Self::Error>;
type Future = Ready<Io<Layer<TestFilter, F>>, Self::Error>;
fn create(self, st: Io<F>) -> Self::Future {
st.map_filter(|_| Ok(NullFilter)).into()
Ready::Ok(st.add_filter(TestFilter).into())
}
}
@ -131,7 +144,7 @@ mod tests {
async fn test_utils_filter() {
let (_, server) = IoTest::create();
let svc = pipeline_factory(
filter::<_, crate::filter::Base>(NullFilterFactory)
filter::<_, crate::filter::Base>(TestFilterFactory)
.map_err(|_| ())
.map_init_err(|_| ()),
)
@ -147,8 +160,15 @@ mod tests {
#[ntex::test]
async fn test_null_filter() {
let (_, server) = IoTest::create();
let io = Io::new(server);
let ioref = io.get_ref();
let mut stack = Stack::new();
assert!(NullFilter.query(std::any::TypeId::of::<()>()).is_none());
assert!(NullFilter.poll_shutdown().is_ready());
assert!(NullFilter
.shutdown(&ioref, &mut stack, 0)
.unwrap()
.is_ready());
assert_eq!(
ntex_util::future::poll_fn(|cx| NullFilter.poll_read_ready(cx)).await,
crate::ReadStatus::Terminate
@ -157,16 +177,12 @@ mod tests {
ntex_util::future::poll_fn(|cx| NullFilter.poll_write_ready(cx)).await,
crate::WriteStatus::Terminate
);
assert_eq!(NullFilter.get_read_buf(), None);
assert_eq!(NullFilter.get_write_buf(), None);
assert!(NullFilter.release_write_buf(BytesVec::new()).is_ok());
NullFilter.release_read_buf(BytesVec::new());
let (_, server) = IoTest::create();
let io = Io::new(server);
assert!(NullFilter.process_write_buf(&ioref, &mut stack, 0).is_ok());
assert_eq!(
NullFilter.process_read_buf(&io.get_ref(), 10).unwrap(),
(0, 0)
NullFilter
.process_read_buf(&ioref, &mut stack, 0, 0)
.unwrap(),
(0)
)
}
}

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.1] - 2023-01-23
* Update filter implementation
## [0.2.0] - 2023-01-04
* Release

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "0.2.0"
version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -25,15 +25,15 @@ openssl = ["tls_openssl"]
rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.18"
ntex-io = "0.2.0"
ntex-bytes = "0.1.19"
ntex-io = "0.2.1"
ntex-util = "0.2.0"
ntex-service = "1.0.0"
log = "0.4"
pin-project-lite = "0.2"
# openssl
tls_openssl = { version="0.10.42", package = "openssl", optional = true }
tls_openssl = { version="0.10", package = "openssl", optional = true }
# rustls
tls_rust = { version = "0.20", package = "rustls", optional = true }

View file

@ -1,31 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFPjCCAyYCCQDWGwiaSniPcTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV
MIIFPjCCAyYCCQDvLYiYD+jqeTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV
UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww
CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0yMTEyMTgx
NjMwNDlaFw0yMjEyMTgxNjMwNDlaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD
CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xODAxMjUx
NzQ2MDFaFw0xOTAxMjUxNzQ2MDFaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD
QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY
MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A
MIICCgKCAgEAryUL1k7npaMOck9OO+EjzeL0FoysOP5JrgRh+8BoPY7WPyL56oFP
aCYKp2YMucmvFh/VZSyupC75JJNIaW0fvcIe4Euzy2Ex0VukPxYteRicaWRsxSId
o5RNNHd7JOf3ZWIMqkxmDhPNnqSGHcnVs/14+I5IbJCoba+KNElmL9CrL3gQkqNY
Jf2FSIgou5j1OthEdnQpiRxSRLmJ7gXtvpFGgj4AnrHGsMAPHueeop6yOX6egFnw
2cwp98c/0tMOUsXnDU1MTGF11+4UVr043SruZKU7bvhMZRcf4NTR2MNin0b3DYJ+
JbTn+HgPhhhx3mrsWRyCvfP23jzwnV/222o+U46i7tNYYrDN8vXIM17gtIvKrv2F
CLTJE6tsp0xAi6dT+J+AIVqkJntrsxqx2CuOYGOOkPPc4rSf64bwOR1mikdvZCnV
NwGEXcH3nBRFMlk5bByCW0kUy03QNakiUEF+PoFzLrCL+V+21Q6Fd7Jmw06BzVFV
2YtsqFcSo7HXW91XJTDVJCPnrMJOooKQ9Fbq4zbQM0Lv02LyJWyR+0PMBzy4FfkW
ZWz10g3w+CITL/MQ65fsBBc9hRHC3QBWetj3puqM8DlPwqPhgmCA5zo8AWx7CogR
V66ukkeBYXYFHwV5uDJTX91tbwYesOL43rlDT905aV0VbaAyDZflipMCAwEAATAN
BgkqhkiG9w0BAQsFAAOCAgEAWeq502+YKMHrk8YD4L2mzY/AHSEWf6XubMgkNRbh
s72+zJs2SrAzu+y+iv5La4JXOxrWEvZOUCKAK0sRG/+ESQxul5mbyPQLWFJgSqv5
O2RmhQ65l+O6RjPZbHPNJMTLMMlkFrKctgGIg5ysKHWPEZZ7ZlS3maxon+X75/b5
uI3BxBpJTWcg6zOxh0+zIxhesgEbRmaEz6qu3ZSktBeUQFpTElreCcbkntlFbr+9
SiKkaO4l6qEwRDhA595/7/JRZo4R5U1MifU6JhTMOyXTsH3BV1aVeS81/9jGPHl8
kgVxeKSpL/jDwuSJdr+dMxs/TJHV6fsnVewFFFmigLWThYGDnKmXqJQNyt8utRpe
6vvReWSSIece1EdBActy0rtjPaUJNTTdYk1UYo63OIbCguLWQD1XYN1qJg4KWJzB
PjS6KCOLmJvYrAxRMED4XeZ17+PxC3xr2IpAL+loRhZUuxXV4GhccGZ4z89OIdOU
x97x2BjjV5Nnnt6eBfF3vP5sOz31QpAS/8tzdlGD+6Xq2/i1ZKMPrwgs2dhTyah0
kCBfdE88Zew/A79z55IsVNiYJ4MrD8WTFjcM2j8SgI7tg+M+X/unj+wnzYT0L0dg
BEfzPd7zWdDOPInlTV9zUj1WOsLHX9odOh/Jj5X0FV5vZtcyQ0sGJAhdgTaXDvXs
Ing=
MIICCgKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEPn8k1
sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+MIK5U
NLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM54jXy
voLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZWLWr
odGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAkoqND
xdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNliJDmA
CRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6/stI
yFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuDYX2U
UuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nPwPTO
vRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA69un
CEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEAATAN
BgkqhkiG9w0BAQsFAAOCAgEApavsgsn7SpPHfhDSN5iZs1ILZQRewJg0Bty0xPfk
3tynSW6bNH3nSaKbpsdmxxomthNSQgD2heOq1By9YzeOoNR+7Pk3s4FkASnf3ToI
JNTUasBFFfaCG96s4Yvs8KiWS/k84yaWuU8c3Wb1jXs5Rv1qE1Uvuwat1DSGXSoD
JNluuIkCsC4kWkyq5pWCGQrabWPRTWsHwC3PTcwSRBaFgYLJaR72SloHB1ot02zL
d2age9dmFRFLLCBzP+D7RojBvL37qS/HR+rQ4SoQwiVc/JzaeqSe7ZbvEH9sZYEu
ALowJzgbwro7oZflwTWunSeSGDSltkqKjvWvZI61pwfHKDahUTmZ5h2y67FuGEaC
CIOUI8dSVSPKITxaq3JL4ze2e9/0Lt7hj19YK2uUmtMAW5Tirz4Yx5lyGH9U8Wur
y/X8VPxTc4A9TMlJgkyz0hqvhbPOT/zSWB10zXh0glKAsSBryAOEDxV1UygmSir7
YV8Qaq+oyKUTMc1MFq5vZ07M51EPaietn85t8V2Y+k/8XYltRp32NxsypxAJuyxh
g/ko6RVTrWa1sMvz/F9LFqAdKiK5eM96lh9IU4xiLg4ob8aS/GRAA8oIFkZFhLrt
tOwjIUPmEPyHWFi8dLpNuQKYalLYhuwZftG/9xV+wqhKGZO9iPrpHSYBRTap8w2y
1QU=
-----END CERTIFICATE-----

View file

@ -1,51 +1,51 @@
-----BEGIN RSA PRIVATE KEY-----
MIIJKQIBAAKCAgEAryUL1k7npaMOck9OO+EjzeL0FoysOP5JrgRh+8BoPY7WPyL5
6oFPaCYKp2YMucmvFh/VZSyupC75JJNIaW0fvcIe4Euzy2Ex0VukPxYteRicaWRs
xSIdo5RNNHd7JOf3ZWIMqkxmDhPNnqSGHcnVs/14+I5IbJCoba+KNElmL9CrL3gQ
kqNYJf2FSIgou5j1OthEdnQpiRxSRLmJ7gXtvpFGgj4AnrHGsMAPHueeop6yOX6e
gFnw2cwp98c/0tMOUsXnDU1MTGF11+4UVr043SruZKU7bvhMZRcf4NTR2MNin0b3
DYJ+JbTn+HgPhhhx3mrsWRyCvfP23jzwnV/222o+U46i7tNYYrDN8vXIM17gtIvK
rv2FCLTJE6tsp0xAi6dT+J+AIVqkJntrsxqx2CuOYGOOkPPc4rSf64bwOR1mikdv
ZCnVNwGEXcH3nBRFMlk5bByCW0kUy03QNakiUEF+PoFzLrCL+V+21Q6Fd7Jmw06B
zVFV2YtsqFcSo7HXW91XJTDVJCPnrMJOooKQ9Fbq4zbQM0Lv02LyJWyR+0PMBzy4
FfkWZWz10g3w+CITL/MQ65fsBBc9hRHC3QBWetj3puqM8DlPwqPhgmCA5zo8AWx7
CogRV66ukkeBYXYFHwV5uDJTX91tbwYesOL43rlDT905aV0VbaAyDZflipMCAwEA
AQKCAgBoOnqt4a0XNE8PlcRv/A6Loskxdiuzixib133cDOe74nn7frwNY0C3MRRc
BG4ETlLErtMWb53KlS2tJ30LSGaATbqELmjj2oaEGa5H4NHU4+GJErtsIV5UD5hW
ZdhB4U2n5s60tdxx+jT+eNhbd9aWU3yfJkVRXlDtXW64qQmH4P1OtXvfWBfIG/Qq
cuUSpvchOrybZYumTdVjkqrTnHGcW+YC8hT6W79rRhB5issr6ZcUghafOWcMpeQ/
0TJZK0K13ZIfp2WFeuZfRw6Rg/AIJllSScZxxo/oBPfym5P6FGRndxrkzkh19g+q
HQDYA0oYW7clXMMtebbrEIb8kLRdaIHDiwyFXmyywvuAAk0jHbA8snM2dyeJWSRr
WQjvQFccGF4z390ZGUCN0ZeESskndg12r4jYaL/aQ8dQZ1ivS69F8bmbQtQNU2Ej
hscTUzEMOnrBTxvRQTjI9nnrbsbklagKmJHXOc/fj13g6/FkcfmeTrjuB30LxJyH
j+xXAi8AGv/oZRk6s/txas5hXpcFXnQDRobVoJjV8kuomcDTt1j33H+05ACFyvHM
/2jxJ1f3xbFx3fqivL89+Z4r8RYxLoWLg7QuqQLdtRgThEKUG0t3lt59fUo+JVVJ
CgRbj/OM3n5udgiIeBAyMAMZjVPUKhvLIFpiUY2vKnYx/97L0QKCAQEA4QUt3dEh
P0L4eQEAzg/J8JuleH7io5VxoK5c2oulhCdQdRDF5HWSesPKJmCmgJRmIXi7zseo
Sbg7Hd2xt/QnaPhRnASXJOdn7ddtoZ1M6Zb0y+d6mmcG+mK6PshtMCQ5S3Lqhsuh
tYQbwawNlCFzwzCzwGb3aD9lBKQYts7KFrMT3Goexg3Qqv374XGn6Eg1LMhXWYbT
M5gcPOYnOT+RugeaTxMnJ6nr6E7kyrLIS+xASXKwXGxSUsQG9VWH7jDuzzARrPEU
aeyxWdbDkBn2vzW+wDpMPMqzoShZsRC9NnFfncXRZfUC5DJWGzwA/xZaR0ZNNng2
OE7rILyAH/aZSQKCAQEAx0ICGi7y94vn5KWdaNVol3nPid4aMnk4LDcX5m0tiqUG
7LIqnFDOOjEdxTf13n7Cv4gotOQNluOypswSDZI4tI0xQ/dJ8PI+vwmA0oHSzf7U
ZPO2gzIOzububPllQsCrKHN++2SyyNlKyYFu/akmlu6yIN3EMRLqYKvZaIL5z9Lk
pTU7eS0AsXJyqD54zRLFkw6S9omQHJXrEzYAuZI+Ue/Arlgyq95mUMsHYRHgaTq4
GDMDLHNyrdKUhW+ZiZ9dhX+aRghHgNiXDk/Eh2/RZrLhKdVk94dJQbfGu/aiSk71
dXPEAaQ7o1MDwQgu4TsCVCzac/CeqvmcoMFyx3NA+wKCAQEAoLfLR8hsH7wcroiZ
45QBXzo8WLD//WjrDKIdLfdaE+bkn4iIX6HeKpMXGowjwGi9/aA3O/z85RKSHsXO
fp4DXAUofPAGaFRjtcwNwMYSPjEUzWKa/hciM8o6TkdnPWBSD+KXQgnFiVk/Xfge
hrPR9BMgAAdLJIlLBKKUCFXwn3/uaprdOgZ6CPd5ZU+BZvXUDRVW1lnnFc3KNXEJ
iOkvk5iEjYAXkkvadEWNQn2pdBjc3djtwEWaEwVyFt6tROJsX01tAoH6W6G0Fn+/
lHgG9hFUGgZJl44L+MpSLZbQHkehzJWS92ilVQni2HbmG0wC1S+QTJxV1agAZpRc
SvgeCQKCAQB3PnVrnfUhV8Sq/MG63xv8qpUc+KHM2uZW75GKAIRkmGYQeH8vlNwV
zxb104t8X3fEj4Ns3Z2UUyey0iVrobn1sxlshyzk2NPcF5/UWoUBaiNJVuA+m1Jp
V6IP7SBAVnUXfCbd42Fq+T7cYG0/uF6zrJ1FNfIXPC6vM6ij9t3xFVBn3fd9iQUF
LGyZaul4MGe0neAtUh3APae0k3jTlUVeW5B/xaBtYmbwqs/7s2sNDmrlcIHRtDVI
+OCRCjxkM88P+VEl4AaKgRPFKM+ADdbPEvXUxzPpPjkE7yorimmM9rvGUkVWhiZ6
k0+H0ZHckCfQoBcLk1AhGcg2HA7IdZzJAoIBAQDAicb6CWlNdaIcJfADKSNK4+BF
JFbH+lXYrTxVSTV+Ubdi0w8Kwk0bzf20EstJnaOCyLDCjcxafjbmuGBVbw7an0lt
Yxjx0fWXxMfvb9/3vKuJVUySA4iq/zfXKlokRaFoqbdRfod3PVGUsynCV7HmImf3
RZA0WkcSwzbg2E2QNKQ3CPd3cHtPpBX8TwRCotg/R5yCR9lihVfkSyULikwBFvrm
2UKZm4pPESWSfMHBToJoAeO0g67itbwwpNhwvgUdyjaj8u46qyjN1FMx3mBiv7Yq
CIE+H0qNu0jmFhoqPrgxfFrGCi6eDPYjRS86536Nc4m8y24z2hie8JLK8QKQ
MIIJKAIBAAKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEP
n8k1sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+M
IK5UNLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM5
4jXyvoLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZ
WLWrodGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAk
oqNDxdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNli
JDmACRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6
/stIyFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuD
YX2UUuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nP
wPTOvRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA
69unCEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEA
AQKCAgAME3aoeXNCPxMrSri7u4Xnnk71YXl0Tm9vwvjRQlMusXZggP8VKN/KjP0/
9AE/GhmoxqPLrLCZ9ZE1EIjgmZ9Xgde9+C8rTtfCG2RFUL7/5J2p6NonlocmxoJm
YkxYwjP6ce86RTjQWL3RF3s09u0inz9/efJk5O7M6bOWMQ9VZXDlBiRY5BYvbqUR
6FeSzD4MnMbdyMRoVBeXE88gTvZk8xhB6DJnLzYgc0tKiRoeKT0iYv5JZw25VyRM
ycLzfTrFmXCPfB1ylb483d9Ly4fBlM8nkx37PzEnAuukIawDxsPOb9yZC+hfvNJI
7NFiMN+3maEqG2iC00w4Lep4skHY7eHUEUMl+Wjr+koAy2YGLWAwHZQTm7iXn9Ab
L6adL53zyCKelRuEQOzbeosJAqS+5fpMK0ekXyoFIuskj7bWuIoCX7K/kg6q5IW+
vC2FrlsrbQ79GztWLVmHFO1I4J9M5r666YS0qdh8c+2yyRl4FmSiHfGxb3eOKpxQ
b6uI97iZlkxPF9LYUCSc7wq0V2gGz+6LnGvTHlHrOfVXqw/5pLAKhXqxvnroDTwz
0Ay/xFF6ei/NSxBY5t8ztGCBm45wCU3l8pW0X6dXqwUipw5b4MRy1VFRu6rqlmbL
OPSCuLxqyqsigiEYsBgS/icvXz9DWmCQMPd2XM9YhsHvUq+R4QKCAQEA98EuMMXI
6UKIt1kK2t/3OeJRyDd4iv/fCMUAnuPjLBvFE4cXD/SbqCxcQYqb+pue3PYkiTIC
71rN8OQAc5yKhzmmnCE5N26br/0pG4pwEjIr6mt8kZHmemOCNEzvhhT83nfKmV0g
9lNtuGEQMiwmZrpUOF51JOMC39bzcVjYX2Cmvb7cFbIq3lR0zwM+aZpQ4P8LHCIu
bgHmwbdlkLyIULJcQmHIbo6nPFB3ZZE4mqmjwY+rA6Fh9rgBa8OFCfTtrgeYXrNb
IgZQ5U8GoYRPNC2ot0vpTinraboa/cgm6oG4M7FW1POCJTl+/ktHEnKuO5oroSga
/BSg7hCNFVaOhwKCAQEA4Kkys0HtwEbV5mY/NnvUD5KwfXX7BxoXc9lZ6seVoLEc
KjgPYxqYRVrC7dB2YDwwp3qcRTi/uBAgFNm3iYlDzI4xS5SeaudUWjglj7BSgXE2
iOEa7EwcvVPluLaTgiWjlzUKeUCNNHWSeQOt+paBOT+IgwRVemGVpAgkqQzNh/nP
tl3p9aNtgzEm1qVlPclY/XUCtf3bcOR+z1f1b4jBdn0leu5OhnxkC+Htik+2fTXD
jt6JGrMkanN25YzsjnD3Sn+v6SO26H99wnYx5oMSdmb8SlWRrKtfJHnihphjG/YY
l1cyorV6M/asSgXNQfGJm4OuJi0I4/FL2wLUHnU+JwKCAQEAzh4WipcRthYXXcoj
gMKRkMOb3GFh1OpYqJgVExtudNTJmZxq8GhFU51MR27Eo7LycMwKy2UjEfTOnplh
Us2qZiPtW7k8O8S2m6yXlYUQBeNdq9IuuYDTaYD94vsazscJNSAeGodjE+uGvb1q
1wLqE87yoE7dUInYa1cOA3+xy2/CaNuviBFJHtzOrSb6tqqenQEyQf6h9/12+DTW
t5pSIiixHrzxHiFqOoCLRKGToQB+71rSINwTf0nITNpGBWmSj5VcC3VV3TG5/XxI
fPlxV2yhD5WFDPVNGBGvwPDSh4jSMZdZMSNBZCy4XWFNSKjGEWoK4DFYed3DoSt9
5IG1YwKCAQA63ntHl64KJUWlkwNbboU583FF3uWBjee5VqoGKHhf3CkKMxhtGqnt
+oN7t5VdUEhbinhqdx1dyPPvIsHCS3K1pkjqii4cyzNCVNYa2dQ00Qq+QWZBpwwc
3GAkz8rFXsGIPMDa1vxpU6mnBjzPniKMcsZ9tmQDppCEpBGfLpio2eAA5IkK8eEf
cIDB3CM0Vo94EvI76CJZabaE9IJ+0HIJb2+jz9BJ00yQBIqvJIYoNy9gP5Xjpi+T
qV/tdMkD5jwWjHD3AYHLWKUGkNwwkAYFeqT/gX6jpWBP+ZRPOp011X3KInJFSpKU
DT5GQ1Dux7EMTCwVGtXqjO8Ym5wjwwsfAoIBAEcxlhIW1G6BiNfnWbNPWBdh3v/K
5Ln98Rcrz8UIbWyl7qNPjYb13C1KmifVG1Rym9vWMO3KuG5atK3Mz2yLVRtmWAVc
fxzR57zz9MZFDun66xo+Z1wN3fVxQB4CYpOEI4Lb9ioX4v85hm3D6RpFukNtRQEc
Gfr4scTjJX4jFWDp0h6ffMb8mY+quvZoJ0TJqV9L9Yj6Ksdvqez/bdSraev97bHQ
4gbQxaTZ6WjaD4HjpPQefMdWp97Metg0ZQSS8b8EzmNFgyJ3XcjirzwliKTAQtn6
I2sd0NCIooelrKRD8EJoDUwxoOctY7R97wpZ7/wEHU45cBCbRV3H4JILS5c=
-----END RSA PRIVATE KEY-----

View file

@ -27,7 +27,8 @@ async fn main() -> io::Result<()> {
// rustls connector
let connector = connect::rustls::Connector::new(config.clone());
let io = connector.connect("www.rust-lang.org:443").await.unwrap();
//let io = connector.connect("www.rust-lang.org:443").await.unwrap();
let io = connector.connect("127.0.0.1:8443").await.unwrap();
println!("Connected to tls server {:?}", io.query::<PeerAddr>().get());
let result = io
.send(Bytes::from_static(b"GET /\r\n\r\n"), &codec::BytesCodec)

View file

@ -1,7 +1,7 @@
use std::task::{Context, Poll};
use std::{error::Error, future::Future, marker::PhantomData, pin::Pin};
use ntex_io::{Filter, FilterFactory, Io};
use ntex_io::{Filter, FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory};
use ntex_util::{future::Ready, time::Millis};
use tls_openssl::ssl::SslAcceptor;
@ -53,7 +53,7 @@ impl<F> Clone for Acceptor<F> {
}
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
type Response = Io<SslFilter<F>>;
type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>;
type Service = AcceptorService<F>;
type InitError = ();
@ -81,7 +81,7 @@ pub struct AcceptorService<F> {
}
impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
type Response = Io<SslFilter<F>>;
type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>;
type Future<'f> = AcceptorServiceResponse<F>;
@ -115,7 +115,7 @@ pin_project_lite::pin_project! {
}
impl<F: Filter> Future for AcceptorServiceResponse<F> {
type Output = Result<Io<SslFilter<F>>, Box<dyn Error>>;
type Output = Result<Io<Layer<SslFilter, F>>, Box<dyn Error>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().fut.poll(cx)

View file

@ -1,10 +1,9 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::{Cell, RefCell};
use std::{any, cmp, error::Error, io, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesVec, PoolRef};
use ntex_io::{types, Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis};
use tls_openssl::ssl::{self, NameType, SslStream};
use tls_openssl::x509::X509;
@ -23,28 +22,27 @@ pub struct PeerCert(pub X509);
pub struct PeerCertChain(pub Vec<X509>);
/// An implementation of SSL streams
pub struct SslFilter<F = Base> {
inner: RefCell<SslStream<IoInner<F>>>,
pub struct SslFilter {
inner: RefCell<SslStream<IoInner>>,
pool: PoolRef,
handshake: Cell<bool>,
read_buf: Cell<Option<BytesVec>>,
}
struct IoInner<F> {
inner: F,
struct IoInner {
inner_read_buf: Option<BytesVec>,
inner_write_buf: Option<BytesVec>,
pool: PoolRef,
write_buf: Option<BytesVec>,
}
impl<F: Filter> io::Read for IoInner<F> {
impl io::Read for IoInner {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut buf) = self.inner.get_read_buf() {
if let Some(mut buf) = self.inner_read_buf.take() {
if buf.is_empty() {
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let len = cmp::min(buf.len(), dst.len());
dst[..len].copy_from_slice(&buf.split_to(len));
self.inner.release_read_buf(buf);
self.inner_read_buf = Some(buf);
Ok(len)
}
} else {
@ -53,16 +51,16 @@ impl<F: Filter> io::Read for IoInner<F> {
}
}
impl<F: Filter> io::Write for IoInner<F> {
impl io::Write for IoInner {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.inner.get_write_buf() {
let mut buf = if let Some(mut buf) = self.inner_write_buf.take() {
buf.reserve(src.len());
buf
} else {
BytesVec::with_capacity_in(src.len(), self.pool)
};
buf.extend_from_slice(src);
self.inner.release_write_buf(buf)?;
self.inner_write_buf = Some(buf);
Ok(src.len())
}
@ -71,7 +69,37 @@ impl<F: Filter> io::Write for IoInner<F> {
}
}
impl<F: Filter> Filter for SslFilter<F> {
impl SslFilter {
fn with_buffers<F, R>(&self, buf: &mut WriteBuf<'_>, f: F) -> R
where
F: FnOnce() -> R,
{
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
let result = f();
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
result
}
fn set_buffers(&self, buf: &mut WriteBuf<'_>) {
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
}
fn unset_buffers(&self, buf: &mut WriteBuf<'_>) {
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
}
}
impl FilterLayer for SslFilter {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2";
@ -116,86 +144,37 @@ impl<F: Filter> Filter for SslFilter<F> {
None
}
} else {
self.inner.borrow().get_ref().inner.query(id)
None
}
}
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
let ssl_result = self.inner.borrow_mut().shutdown();
fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
let ssl_result = self.with_buffers(buf, || self.inner.borrow_mut().shutdown());
match ssl_result {
Ok(ssl::ShutdownResult::Sent) => Poll::Pending,
Ok(ssl::ShutdownResult::Received) => {
self.inner.borrow().get_ref().inner.poll_shutdown()
}
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
self.inner.borrow().get_ref().inner.poll_shutdown()
}
Ok(ssl::ShutdownResult::Sent) => Ok(Poll::Pending),
Ok(ssl::ShutdownResult::Received) => Ok(Poll::Ready(())),
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(Poll::Ready(())),
Err(ref e)
if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE =>
{
Poll::Pending
Ok(Poll::Pending)
}
Err(e) => Poll::Ready(Err(e
Err(e) => Err(e
.into_io_error()
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))),
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
}
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.borrow().get_ref().inner.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.borrow().get_ref().inner.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
self.inner.borrow_mut().get_mut().write_buf.take()
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.read_buf.set(Some(buf));
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
// ask inner filter to process read buf
match self
.inner
.borrow_mut()
.get_ref()
.inner
.process_read_buf(io, nbytes)
{
Err(err) => io.want_shutdown(Some(err)),
Ok((n, 0)) => return Ok((n, 0)),
Ok((_, _)) => (),
}
// get processed buffer
let mut dst = if let Some(dst) = self.get_read_buf() {
dst
} else {
self.pool.get_read_buf()
};
let (hw, lw) = self.pool.read_params().unpack();
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
buf.with_write_buf(|b| self.set_buffers(b));
let dst = buf.get_dst();
let mut new_bytes = usize::from(self.handshake.get());
loop {
// make sure we've got room
let remaining = dst.remaining_mut();
if remaining < lw {
dst.reserve(hw - remaining);
}
self.pool.resize_read_buf(dst);
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
@ -209,43 +188,61 @@ impl<F: Filter> Filter for SslFilter<F> {
if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE =>
{
Ok((dst.len(), new_bytes))
Ok(new_bytes)
}
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
io.want_shutdown(None);
Ok((dst.len(), new_bytes))
buf.want_shutdown();
Ok(new_bytes)
}
Err(e) => {
log::trace!("SSL Error: {:?}", e);
Err(map_to_ioerr(e))
}
};
self.release_read_buf(dst);
buf.with_write_buf(|b| self.unset_buffers(b));
return result;
}
}
fn release_write_buf(&self, mut buf: BytesVec) -> Result<(), io::Error> {
loop {
if buf.is_empty() {
return Ok(());
}
let ssl_result = self.inner.borrow_mut().ssl_write(&buf);
match ssl_result {
Ok(v) => {
buf.split_to(v);
continue;
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(mut src) = buf.take_src() {
self.set_buffers(buf);
loop {
if src.is_empty() {
self.unset_buffers(buf);
return Ok(());
}
Err(e) => {
if !buf.is_empty() {
self.inner.borrow_mut().get_mut().write_buf = Some(buf);
let ssl_result = self.inner.borrow_mut().ssl_write(&src);
match ssl_result {
Ok(v) => {
src.split_to(v);
continue;
}
Err(e) => {
if !src.is_empty() {
buf.set_src(Some(src));
}
self.unset_buffers(buf);
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
buf.set_dst(
self.inner
.borrow_mut()
.get_mut()
.inner_write_buf
.take(),
);
Ok(())
}
_ => Err(map_to_ioerr(e)),
};
}
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()),
_ => Err(map_to_ioerr(e)),
};
}
}
} else {
Ok(())
}
}
}
@ -283,42 +280,47 @@ impl Clone for SslAcceptor {
}
impl<F: Filter> FilterFactory<F> for SslAcceptor {
type Filter = SslFilter<F>;
type Filter = SslFilter;
type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Self::Filter>, Self::Error>>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
fn create(self, io: Io<F>) -> Self::Future {
let timeout = self.timeout;
let ctx_result = ssl::Ssl::new(self.acceptor.context());
Box::pin(async move {
time::timeout(timeout, async {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let pool = st.memory_pool();
let st = st.map_filter(|inner: F| {
let inner = IoInner {
pool,
inner,
write_buf: None,
};
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
Ok::<_, Box<dyn Error>>(SslFilter {
pool,
read_buf: Cell::new(None),
handshake: Cell::new(true),
inner: RefCell::new(ssl_stream),
})
})?;
let inner = IoInner {
pool: io.memory_pool(),
inner_read_buf: None,
inner_write_buf: None,
};
let filter = SslFilter {
pool: io.memory_pool(),
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
poll_fn(|cx| {
handle_result(st.filter().inner.borrow_mut().accept(), &st, cx)
let result = io
.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})
.map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
})
.await?;
st.filter().handshake.set(false);
Ok(st)
io.filter().handshake.set(false);
Ok(io)
})
.await
.map_err(|_| {
@ -341,35 +343,42 @@ impl SslConnector {
}
impl<F: Filter> FilterFactory<F> for SslConnector {
type Filter = SslFilter<F>;
type Filter = SslFilter;
type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Self::Filter>, Self::Error>>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
fn create(self, io: Io<F>) -> Self::Future {
Box::pin(async move {
let ssl = self.ssl;
let pool = st.memory_pool();
let st = st.map_filter(|inner: F| {
let inner = IoInner {
pool,
inner,
write_buf: None,
};
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
let inner = IoInner {
pool: io.memory_pool(),
inner_read_buf: None,
inner_write_buf: None,
};
let filter = SslFilter {
pool: io.memory_pool(),
handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
};
let io = io.add_filter(filter);
Ok::<_, Box<dyn Error>>(SslFilter {
pool,
read_buf: Cell::new(None),
handshake: Cell::new(true),
inner: RefCell::new(ssl_stream),
})
})?;
poll_fn(|cx| {
let result = io
.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})
.map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
})
.await?;
poll_fn(|cx| handle_result(st.filter().inner.borrow_mut().connect(), &st, cx))
.await?;
Ok(st)
io.filter().handshake.set(false);
Ok(io)
})
}
}

View file

@ -3,7 +3,7 @@ use std::{future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
use tls_rust::ServerConfig;
use ntex_io::{Filter, FilterFactory, Io};
use ntex_io::{Filter, FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory};
use ntex_util::{future::Ready, time::Millis};
@ -52,7 +52,7 @@ impl<F> Clone for Acceptor<F> {
}
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
type Response = Io<TlsFilter<F>>;
type Response = Io<Layer<TlsFilter, F>>;
type Error = io::Error;
type Service = AcceptorService<F>;
@ -79,7 +79,7 @@ pub struct AcceptorService<F> {
}
impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
type Response = Io<TlsFilter<F>>;
type Response = Io<Layer<TlsFilter, F>>;
type Error = io::Error;
type Future<'f> = AcceptorServiceFut<F>;
@ -113,7 +113,7 @@ pin_project_lite::pin_project! {
}
impl<F: Filter> Future for AcceptorServiceFut<F> {
type Output = Result<Io<TlsFilter<F>>, io::Error>;
type Output = Result<Io<Layer<TlsFilter, F>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().fut.poll(cx)

View file

@ -1,9 +1,9 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Context, task::Poll};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
use ntex_bytes::{BufMut, BytesVec};
use ntex_io::{types, Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_bytes::BufMut;
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, ready};
use tls_rust::{ClientConfig, ClientConnection, ServerName};
@ -12,12 +12,12 @@ use crate::rustls::{IoInner, TlsFilter, Wrapper};
use super::{PeerCert, PeerCertChain};
/// An implementation of SSL streams
pub struct TlsClientFilter<F> {
inner: IoInner<F>,
pub struct TlsClientFilter {
inner: IoInner,
session: RefCell<ClientConnection>,
}
impl<F: Filter> Filter for TlsClientFilter<F> {
impl FilterLayer for TlsClientFilter {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2";
@ -52,71 +52,19 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
None
}
} else {
self.inner.filter.query(id)
None
}
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.filter.poll_shutdown()
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.filter.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.filter.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.inner.read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
self.inner.write_buf.take()
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.inner.read_buf.set(Some(buf));
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
let mut session = self.session.borrow_mut();
// ask inner filter to process read buf
match self.inner.filter.process_read_buf(io, nbytes) {
Err(err) => io.want_shutdown(Some(err)),
Ok((_, 0)) => return Ok((0, 0)),
Ok(_) => (),
}
// get processed buffer
let mut dst = if let Some(dst) = self.inner.read_buf.take() {
dst
} else {
self.inner.pool.get_read_buf()
};
let (hw, lw) = self.inner.pool.read_params().unpack();
let mut src = if let Some(src) = self.inner.filter.get_read_buf() {
src
} else {
return Ok((0, 0));
};
let (src, dst) = buf.get_pair();
let mut new_bytes = usize::from(self.inner.handshake.get());
loop {
// make sure we've got room
let remaining = dst.remaining_mut();
if remaining < lw {
dst.reserve(hw - remaining);
}
self.inner.pool.resize_read_buf(dst);
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
@ -138,73 +86,74 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
}
}
let dst_len = dst.len();
self.inner.read_buf.set(Some(dst));
self.inner.filter.release_read_buf(src);
Ok((dst_len, new_bytes))
Ok(new_bytes)
}
fn release_write_buf(&self, mut src: BytesVec) -> Result<(), io::Error> {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner);
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(mut src) = buf.take_src() {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
}
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
}
let n = session.write_tls(&mut io)?;
if n == 0 {
break;
buf.set_src(Some(src));
}
Ok(())
} else {
Ok(())
}
if !src.is_empty() {
self.inner.write_buf.set(Some(src));
}
Ok(())
}
}
impl<F: Filter> TlsClientFilter<F> {
pub(crate) async fn create(
impl TlsClientFilter {
pub(crate) async fn create<F: Filter>(
io: Io<F>,
cfg: Arc<ClientConfig>,
domain: ServerName,
) -> Result<Io<TlsFilter<F>>, io::Error> {
let pool = io.memory_pool();
let session = match ClientConnection::new(cfg, domain) {
Ok(session) => session,
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
};
let io = io.map_filter(|filter: F| {
let inner = IoInner {
pool,
filter,
read_buf: Cell::new(None),
write_buf: Cell::new(None),
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
let session = ClientConnection::new(cfg, domain)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_client(TlsClientFilter {
inner: IoInner {
pool: io.memory_pool(),
handshake: Cell::new(true),
};
Ok::<_, io::Error>(TlsFilter::new_client(TlsClientFilter {
inner,
session: RefCell::new(session),
}))
})?;
},
session: RefCell::new(session),
});
let io = io.add_filter(filter);
let filter = io.filter();
loop {
let (result, wants_read, handshaking) = {
let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.client().session.borrow_mut();
let mut wrp = Wrapper(&filter.client().inner);
(
let mut wrp = Wrapper(&filter.client().inner, buf);
let mut result = (
session.complete_io(&mut wrp),
session.wants_read(),
session.is_handshaking(),
)
};
);
while session.wants_write() {
result.0 = session.complete_io(&mut wrp);
if result.0.is_err() {
break;
}
}
result
})?;
match result {
Ok(_) => {
filter.client().inner.handshake.set(false);

View file

@ -1,11 +1,13 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll};
use std::{cell::Cell, sync::Arc};
use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll};
use ntex_bytes::{BytesVec, PoolRef};
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::time::Millis;
use ntex_bytes::PoolRef;
use ntex_io::{
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
WriteStatus,
};
use ntex_util::{future::BoxFuture, time::Millis};
use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName};
mod accept;
@ -25,33 +27,33 @@ pub struct PeerCert(pub Certificate);
pub struct PeerCertChain(pub Vec<Certificate>);
/// An implementation of SSL streams
pub struct TlsFilter<F = Base> {
inner: InnerTlsFilter<F>,
pub struct TlsFilter {
inner: InnerTlsFilter,
}
enum InnerTlsFilter<F> {
Server(TlsServerFilter<F>),
Client(TlsClientFilter<F>),
enum InnerTlsFilter {
Server(TlsServerFilter),
Client(TlsClientFilter),
}
impl<F> TlsFilter<F> {
fn new_server(server: TlsServerFilter<F>) -> Self {
impl TlsFilter {
fn new_server(server: TlsServerFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Server(server),
}
}
fn new_client(client: TlsClientFilter<F>) -> Self {
fn new_client(client: TlsClientFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Client(client),
}
}
fn server(&self) -> &TlsServerFilter<F> {
fn server(&self) -> &TlsServerFilter {
match self.inner {
InnerTlsFilter::Server(ref server) => server,
_ => unreachable!(),
}
}
fn client(&self) -> &TlsClientFilter<F> {
fn client(&self) -> &TlsClientFilter {
match self.inner {
InnerTlsFilter::Client(ref server) => server,
_ => unreachable!(),
@ -59,7 +61,7 @@ impl<F> TlsFilter<F> {
}
}
impl<F: Filter> Filter for TlsFilter<F> {
impl FilterLayer for TlsFilter {
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
match self.inner {
@ -69,10 +71,10 @@ impl<F: Filter> Filter for TlsFilter<F> {
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_shutdown(),
InnerTlsFilter::Client(ref f) => f.poll_shutdown(),
InnerTlsFilter::Server(ref f) => f.shutdown(buf),
InnerTlsFilter::Client(ref f) => f.shutdown(buf),
}
}
@ -93,42 +95,18 @@ impl<F: Filter> Filter for TlsFilter<F> {
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.get_read_buf(),
InnerTlsFilter::Client(ref f) => f.get_read_buf(),
InnerTlsFilter::Server(ref f) => f.process_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_read_buf(buf),
}
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.get_write_buf(),
InnerTlsFilter::Client(ref f) => f.get_write_buf(),
}
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.release_read_buf(buf),
}
}
#[inline]
fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb),
InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb),
}
}
#[inline]
fn release_write_buf(&self, src: BytesVec) -> Result<(), io::Error> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_write_buf(src),
InnerTlsFilter::Client(ref f) => f.release_write_buf(src),
InnerTlsFilter::Server(ref f) => f.process_write_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_write_buf(buf),
}
}
}
@ -172,10 +150,10 @@ impl Clone for TlsAcceptor {
}
impl<F: Filter> FilterFactory<F> for TlsAcceptor {
type Filter = TlsFilter<F>;
type Filter = TlsFilter;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, io::Error>>>>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg.clone();
@ -227,10 +205,10 @@ impl Clone for TlsConnectorConfigured {
}
impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
type Filter = TlsFilter<F>;
type Filter = TlsFilter;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, io::Error>>>>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg;
@ -240,44 +218,31 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
}
}
pub(crate) struct IoInner<F> {
filter: F,
pub(crate) struct IoInner {
pool: PoolRef,
read_buf: Cell<Option<BytesVec>>,
write_buf: Cell<Option<BytesVec>>,
handshake: Cell<bool>,
}
pub(crate) struct Wrapper<'a, F>(&'a IoInner<F>);
pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a mut WriteBuf<'b>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut read_buf) = self.0.filter.get_read_buf() {
self.1.with_read_buf(|buf| {
let read_buf = buf.get_src();
let len = cmp::min(read_buf.len(), dst.len());
let result = if len > 0 {
if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
};
self.0.filter.release_read_buf(read_buf);
result
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
})
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesVec::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.filter.release_write_buf(buf)?;
self.1.get_dst().extend_from_slice(src);
Ok(src.len())
}

View file

@ -1,9 +1,9 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Context, task::Poll};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
use ntex_bytes::{BufMut, BytesVec};
use ntex_io::{types, Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_bytes::BufMut;
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection};
@ -13,12 +13,12 @@ use crate::Servername;
use super::{PeerCert, PeerCertChain};
/// An implementation of SSL streams
pub struct TlsServerFilter<F> {
inner: IoInner<F>,
pub struct TlsServerFilter {
inner: IoInner,
session: RefCell<ServerConnection>,
}
impl<F: Filter> Filter for TlsServerFilter<F> {
impl FilterLayer for TlsServerFilter {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2";
@ -59,71 +59,19 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
None
}
} else {
self.inner.filter.query(id)
None
}
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.filter.poll_shutdown()
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.filter.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.filter.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.inner.read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
self.inner.write_buf.take()
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.inner.read_buf.set(Some(buf));
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
let mut session = self.session.borrow_mut();
// ask inner filter to process read buf
match self.inner.filter.process_read_buf(io, nbytes) {
Err(err) => io.want_shutdown(Some(err)),
Ok((_, 0)) => return Ok((0, 0)),
Ok(_) => (),
}
// get processed buffer
let mut dst = if let Some(dst) = self.inner.read_buf.take() {
dst
} else {
self.inner.pool.get_read_buf()
};
let (hw, lw) = self.inner.pool.read_params().unpack();
let mut src = if let Some(src) = self.inner.filter.get_read_buf() {
src
} else {
return Ok((0, 0));
};
let (src, dst) = buf.get_pair();
let mut new_bytes = usize::from(self.inner.handshake.get());
loop {
// make sure we've got room
let remaining = dst.remaining_mut();
if remaining < lw {
dst.reserve(hw - remaining);
}
self.inner.pool.resize_read_buf(dst);
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
@ -145,73 +93,73 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
}
}
let dst_len = dst.len();
self.inner.read_buf.set(Some(dst));
self.inner.filter.release_read_buf(src);
Ok((dst_len, new_bytes))
Ok(new_bytes)
}
fn release_write_buf(&self, mut src: BytesVec) -> Result<(), io::Error> {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner);
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(mut src) = buf.take_src() {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
}
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
buf.set_src(Some(src));
}
let n = session.write_tls(&mut io)?;
if n == 0 {
break;
}
}
if !src.is_empty() {
self.inner.write_buf.set(Some(src));
}
Ok(())
}
}
impl<F: Filter> TlsServerFilter<F> {
pub(crate) async fn create(
impl TlsServerFilter {
pub(crate) async fn create<F: Filter>(
io: Io<F>,
cfg: Arc<ServerConfig>,
timeout: Millis,
) -> Result<Io<TlsFilter<F>>, io::Error> {
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
time::timeout(timeout, async {
let pool = io.memory_pool();
let session = match ServerConnection::new(cfg) {
Ok(session) => session,
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
};
let io = io.map_filter(|filter: F| {
let inner = IoInner {
pool,
filter,
read_buf: Cell::new(None),
write_buf: Cell::new(None),
let session = ServerConnection::new(cfg)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_server(TlsServerFilter {
session: RefCell::new(session),
inner: IoInner {
pool: io.memory_pool(),
handshake: Cell::new(true),
};
Ok::<_, io::Error>(TlsFilter::new_server(TlsServerFilter {
inner,
session: RefCell::new(session),
}))
})?;
},
});
let io = io.add_filter(filter);
let filter = io.filter();
loop {
let (result, wants_read, handshaking) = {
let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.server().session.borrow_mut();
let mut wrp = Wrapper(&filter.server().inner);
(
let mut wrp = Wrapper(&filter.server().inner, buf);
let mut result = (
session.complete_io(&mut wrp),
session.wants_read(),
session.is_handshaking(),
)
};
);
while session.wants_write() {
result.0 = session.complete_io(&mut wrp);
if result.0.is_err() {
break;
}
}
result
})?;
match result {
Ok(_) => {
filter.server().inner.handshake.set(false);

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tokio"
version = "0.2.0"
version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "tokio intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"]
@ -16,8 +16,8 @@ name = "ntex_tokio"
path = "src/lib.rs"
[dependencies]
ntex-bytes = "0.1.18"
ntex-io = "0.2.0"
ntex-bytes = "0.1.19"
ntex-io = "0.2.1"
ntex-util = "0.2.0"
log = "0.4"
pin-project-lite = "0.2"

View file

@ -54,73 +54,42 @@ impl Future for ReadTask {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
loop {
this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
let mut io = this.io.borrow_mut();
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
return match poll_read_buf(Pin::new(&mut *io), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("tokio stream is disconnected");
close = true;
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
new_bytes += n;
if new_bytes <= hw {
continue;
}
Poll::Pending
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
drop(io);
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
Poll::Ready(Err(err))
}
}
};
}
drop(io);
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
Poll::Ready(Ok(()))
}
}
}
})
}
}
@ -269,14 +238,14 @@ impl Future for WriteTask {
if read_buf.filled().is_empty() =>
{
this.state.close(None);
log::trace!("write task is stopped");
log::trace!("tokio write task is stopped");
return Poll::Ready(());
}
Poll::Pending => {
*count += read_buf.filled().len() as u16;
if *count > 4096 {
log::trace!(
"write task is stopped, too much input"
"tokio write task is stopped, too much input"
);
this.state.close(None);
return Poll::Ready(());
@ -344,7 +313,7 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
}
}
}
log::trace!("flushed {} bytes", written);
// log::trace!("flushed {} bytes", written);
// remove written data
let result = if written == len {
@ -501,18 +470,11 @@ mod unixstream {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
loop {
this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
let mut io = this.io.borrow_mut();
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
@ -520,54 +482,31 @@ mod unixstream {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
return match poll_read_buf(Pin::new(&mut *io), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("unix stream is disconnected");
close = true;
log::trace!("tokio unix stream is disconnected");
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
new_bytes += n;
if new_bytes <= hw {
continue;
}
Poll::Pending
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
drop(io);
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
log::trace!("unix stream read task failed {:?}", err);
Poll::Ready(Err(err))
}
}
};
}
drop(io);
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
Poll::Ready(Ok(()))
}
}
}
})
}
}
@ -735,10 +674,6 @@ pub fn poll_read_buf<T: AsyncRead>(
cx: &mut Context<'_>,
buf: &mut BytesVec,
) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let n = {
let dst =
unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit<u8>]) };

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.6.0"
version = "0.6.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -49,20 +49,20 @@ async-std = ["ntex-rt/async-std", "ntex-async-std", "ntex-connect/async-std"]
[dependencies]
ntex-codec = "0.6.2"
ntex-connect = "0.2.0"
ntex-connect = "0.2.1"
ntex-http = "0.1.9"
ntex-router = "0.5.1"
ntex-service = "1.0.0"
ntex-macros = "0.1.3"
ntex-util = "0.2.0"
ntex-bytes = "0.1.18"
ntex-bytes = "0.1.19"
ntex-h2 = "0.2.0"
ntex-rt = "0.4.7"
ntex-io = "0.2.0"
ntex-tls = "0.2.0"
ntex-tokio = { version = "0.2.0", optional = true }
ntex-glommio = { version = "0.2.0", optional = true }
ntex-async-std = { version = "0.2.0", optional = true }
ntex-io = "0.2.1"
ntex-tls = "0.2.1"
ntex-tokio = { version = "0.2.1", optional = true }
ntex-glommio = { version = "0.2.1", optional = true }
ntex-async-std = { version = "0.2.1", optional = true }
async-oneshot = "0.5.0"
async-channel = "1.8.0"
@ -88,7 +88,7 @@ percent-encoding = "2.1"
serde_json = "1.0"
serde_urlencoded = "0.7"
url-pkg = { version = "2.1", package = "url", optional = true }
coo-kie = { version = "0.16", package = "cookie", optional = true }
coo-kie = { version = "0.17", package = "cookie", optional = true }
# openssl
tls-openssl = { version="0.10", package = "openssl", optional = true }

View file

@ -206,6 +206,7 @@ struct H2ClientInner {
streams: RefCell<HashMap<frame::StreamId, StreamInfo>>,
}
#[derive(Debug)]
struct StreamInfo {
tx: Option<oneshot::Sender<Result<(ResponseHead, Payload), SendRequestError>>>,
stream: Option<h2::Stream>,

View file

@ -10,7 +10,7 @@
//!
//! let response = client.get("http://www.rust-lang.org") // <- Create request builder
//! .header("User-Agent", "ntex::web")
//! .send() // <- Send http request
//! .send() // <- Send http request
//! .await;
//!
//! println!("Response: {:?}", response);

View file

@ -626,7 +626,6 @@ mod tests {
#[crate::rt_test]
async fn test_basics() {
env_logger::init();
let store = Rc::new(RefCell::new(Vec::new()));
let store2 = store.clone();

View file

@ -146,7 +146,7 @@ const DATE_VALUE_DEFAULT: [u8; DATE_VALUE_LENGTH_HDR] = [
b'0', b'0', b'0', b'0', b'0', b'0', b'0', b'\r', b'\n', b'\r', b'\n',
];
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct DateService(Rc<DateServiceInner>);
impl Default for DateService {
@ -155,6 +155,7 @@ impl Default for DateService {
}
}
#[derive(Debug)]
struct DateServiceInner {
current: Cell<bool>,
current_time: Cell<time::Instant>,

View file

@ -21,16 +21,19 @@ bitflags! {
}
}
#[derive(Debug)]
/// HTTP/1 Codec
pub struct ClientCodec {
inner: ClientCodecInner,
}
#[derive(Debug)]
/// HTTP/1 Payload Codec
pub struct ClientPayloadCodec {
inner: ClientCodecInner,
}
#[derive(Debug)]
struct ClientCodecInner {
timer: DateService,
decoder: decoder::MessageDecoder<ResponseHead>,

View file

@ -14,6 +14,7 @@ use super::MAX_BUFFER_SIZE;
const MAX_HEADERS: usize = 96;
#[derive(Debug)]
/// Incoming messagd decoder
pub(super) struct MessageDecoder<T: MessageType>(PhantomData<T>);

View file

@ -1015,8 +1015,8 @@ mod tests {
}
#[crate::rt_test]
/// /// h1 dispatcher still processes all incoming requests
/// /// but it does not write any data to socket
/// h1 dispatcher still processes all incoming requests
/// but it does not write any data to socket
async fn test_write_disconnected() {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
@ -1039,6 +1039,7 @@ mod tests {
assert_eq!(num.load(Ordering::Relaxed), 1);
}
/// max http message size is 32k (no payload)
#[crate::rt_test]
async fn test_read_large_message() {
let (client, server) = Io::create();

View file

@ -3,8 +3,7 @@ use std::{cell::RefCell, error::Error, fmt, marker, rc::Rc, task};
use crate::http::body::MessageBody;
use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig};
use crate::http::error::{DispatchError, ResponseError};
use crate::http::request::Request;
use crate::http::response::Response;
use crate::http::{request::Request, response::Response};
use crate::io::{types, Filter, Io};
use crate::service::{IntoServiceFactory, Service, ServiceFactory};
use crate::{time::Millis, util::BoxFuture};
@ -56,9 +55,9 @@ mod openssl {
use tls_openssl::ssl::SslAcceptor;
use super::*;
use crate::{server::SslError, service::pipeline_factory};
use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B, X, U> H1Service<SslFilter<F>, S, B, X, U>
impl<F, S, B, X, U> H1Service<Layer<SslFilter, F>, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -69,7 +68,8 @@ mod openssl {
X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<SslFilter<F>>, Codec), Response = ()> + 'static,
U: ServiceFactory<(Request, Io<Layer<SslFilter, F>>, Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + Error,
U::InitError: fmt::Debug,
{
@ -102,9 +102,9 @@ mod rustls {
use tls_rustls::ServerConfig;
use super::*;
use crate::{server::SslError, service::pipeline_factory};
use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B, X, U> H1Service<TlsFilter<F>, S, B, X, U>
impl<F, S, B, X, U> H1Service<Layer<TlsFilter, F>, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -115,7 +115,8 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<TlsFilter<F>>, Codec), Response = ()> + 'static,
U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + Error,
U::InitError: fmt::Debug,
{

View file

@ -64,6 +64,7 @@ impl Stream for Payload {
}
}
#[derive(Debug)]
/// Sender part of the payload stream
pub struct PayloadSender {
inner: Weak<RefCell<Inner>>,

View file

@ -49,13 +49,11 @@ mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor;
use crate::io::Filter;
use crate::server::SslError;
use crate::service::pipeline_factory;
use crate::{io::Layer, server::SslError, service::pipeline_factory};
use super::*;
impl<F, S, B> H2Service<SslFilter<F>, S, B>
impl<F, S, B> H2Service<Layer<SslFilter, F>, S, B>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -90,9 +88,9 @@ mod rustls {
use tls_rustls::ServerConfig;
use super::*;
use crate::{server::SslError, service::pipeline_factory};
use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B> H2Service<TlsFilter<F>, S, B>
impl<F, S, B> H2Service<Layer<TlsFilter, F>, S, B>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -394,7 +392,7 @@ where
loop {
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
None => {
log::debug!("{:?} closing sending payload", msg.id());
log::debug!("{:?} closing payload stream", msg.id());
msg.stream().send_payload(Bytes::new(), true).await?;
break;
}

View file

@ -146,10 +146,9 @@ mod openssl {
use tls_openssl::ssl::SslAcceptor;
use super::*;
use crate::server::SslError;
use crate::service::pipeline_factory;
use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B, X, U> HttpService<SslFilter<F>, S, B, X, U>
impl<F, S, B, X, U> HttpService<Layer<SslFilter, F>, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -160,7 +159,8 @@ mod openssl {
X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<SslFilter<F>>, h1::Codec), Response = ()> + 'static,
U: ServiceFactory<(Request, Io<Layer<SslFilter, F>>, h1::Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + error::Error,
U::InitError: fmt::Debug,
{
@ -191,9 +191,9 @@ mod rustls {
use tls_rustls::ServerConfig;
use super::*;
use crate::{server::SslError, service::pipeline_factory};
use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B, X, U> HttpService<TlsFilter<F>, S, B, X, U>
impl<F, S, B, X, U> HttpService<Layer<TlsFilter, F>, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -204,7 +204,8 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<TlsFilter<F>>, h1::Codec), Response = ()> + 'static,
U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, h1::Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + error::Error,
U::InitError: fmt::Debug,
{

View file

@ -4,8 +4,9 @@ use std::{convert::TryFrom, net, str::FromStr, sync::mpsc, thread};
#[cfg(feature = "cookie")]
use coo_kie::{Cookie, CookieJar};
use crate::io::{Filter, Io};
use crate::ws::{error::WsClientError, WsClient, WsConnection};
use crate::{io::Filter, io::Io, rt::System, server::Server, service::ServiceFactory};
use crate::{rt::System, server::Server, service::ServiceFactory};
use crate::{time::Millis, time::Seconds, util::Bytes};
use super::client::{Client, ClientRequest, ClientResponse, Connector};
@ -349,7 +350,10 @@ impl TestServer {
/// Connect to a websocket server
pub async fn wss(
&mut self,
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> {
) -> Result<
WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
WsClientError,
> {
self.wss_at("/").await
}
@ -358,7 +362,10 @@ impl TestServer {
pub async fn wss_at(
&mut self,
path: &str,
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> {
) -> Result<
WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
WsClientError,
> {
use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();

View file

@ -15,13 +15,13 @@ use crate::connect::{Connect, ConnectError, Connector};
use crate::http::header::{self, HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1};
use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri};
use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Sealed};
use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Layer, Sealed};
use crate::service::{apply_fn, into_service, IntoService, Service};
use crate::time::{timeout, Millis, Seconds};
use crate::{channel::mpsc, rt, util::Ready, ws};
use super::error::{WsClientBuilderError, WsClientError, WsError};
use super::transport::{WsTransport, WsTransportFactory};
use super::transport::WsTransport;
/// `WebSocket` client builder
pub struct WsClient<F, T> {
@ -527,7 +527,7 @@ where
pub fn openssl(
&mut self,
connector: openssl::SslConnector,
) -> WsClientBuilder<openssl::SslFilter, openssl::Connector<Uri>> {
) -> WsClientBuilder<Layer<openssl::SslFilter>, openssl::Connector<Uri>> {
self.connector(openssl::Connector::new(connector))
}
@ -536,7 +536,7 @@ where
pub fn rustls(
&mut self,
config: std::sync::Arc<rustls::ClientConfig>,
) -> WsClientBuilder<rustls::TlsFilter, rustls::Connector<Uri>> {
) -> WsClientBuilder<Layer<rustls::TlsFilter>, rustls::Connector<Uri>> {
self.connector(rustls::Connector::from(config))
}
@ -787,12 +787,8 @@ impl<F: Filter> WsConnection<F> {
}
/// Convert to ws stream to plain io stream
pub async fn into_transport(self) -> Io<WsTransport<F>> {
// WsTransportFactory is infallible
self.io
.add_filter(WsTransportFactory::new(self.codec))
.await
.unwrap()
pub fn into_transport(self) -> Io<Layer<WsTransport, F>> {
WsTransport::create(self.io, self.codec)
}
}

View file

@ -1,9 +1,9 @@
//! An implementation of WebSockets base bytes streams
use std::{any, cell::Cell, cmp, io, task::Context, task::Poll};
use std::{cell::Cell, cmp, io, task::Poll};
use crate::codec::{Decoder, Encoder};
use crate::io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use crate::util::{BufMut, BytesVec, PoolRef, Ready};
use crate::io::{Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use crate::util::{BufMut, PoolRef, Ready};
use super::{CloseCode, CloseReason, Codec, Frame, Item, Message};
@ -16,15 +16,24 @@ bitflags::bitflags! {
}
/// An implementation of WebSockets streams
pub struct WsTransport<F = Base> {
inner: F,
pub struct WsTransport {
pool: PoolRef,
codec: Codec,
flags: Cell<Flags>,
read_buf: Cell<Option<BytesVec>>,
}
impl<F: Filter> WsTransport<F> {
impl WsTransport {
/// Create websockets transport
pub fn create<F: Filter>(io: Io<F>, codec: Codec) -> Io<Layer<WsTransport, F>> {
let pool = io.memory_pool();
io.add_filter(WsTransport {
pool,
codec,
flags: Cell::new(Flags::empty()),
})
}
fn insert_flags(&self, flags: Flags) {
let mut f = self.flags.get();
f.insert(flags);
@ -47,21 +56,12 @@ impl<F: Filter> WsTransport<F> {
}
}
impl<F: Filter> Filter for WsTransport<F> {
impl FilterLayer for WsTransport {
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
self.inner.query(id)
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
let flags = self.flags.get();
if !flags.contains(Flags::CLOSED) {
self.insert_flags(Flags::CLOSED);
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
let code = if flags.contains(Flags::PROTO_ERR) {
CloseCode::Protocol
} else {
@ -72,159 +72,100 @@ impl<F: Filter> Filter for WsTransport<F> {
code,
description: None,
})),
&mut b,
buf.get_dst(),
);
self.inner.release_write_buf(b)?;
}
self.inner.poll_shutdown()
Ok(Poll::Ready(()))
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.poll_read_ready(cx)
}
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
if let Some(mut src) = buf.take_src() {
let mut dst = buf.take_dst();
let dst_len = dst.len();
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.poll_write_ready(cx)
}
loop {
// make sure we've got room
self.pool.resize_read_buf(&mut dst);
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.read_buf.take()
}
let frame = if let Some(frame) =
self.codec.decode_vec(&mut src).map_err(|e| {
log::trace!("Failed to decode ws codec frames: {:?}", e);
self.insert_flags(Flags::PROTO_ERR);
io::Error::new(io::ErrorKind::Other, e)
})? {
frame
} else {
break;
};
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
None
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.read_buf.set(Some(buf));
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
// ask inner filter to process read buf
match self.inner.process_read_buf(io, nbytes) {
Err(err) => io.want_shutdown(Some(err)),
Ok((_, 0)) => return Ok((0, 0)),
Ok(_) => (),
}
// get inner buffer
let mut src = if let Some(src) = self.inner.get_read_buf() {
src
} else {
return Ok((0, 0));
};
// get processed buffer
let mut dst = if let Some(dst) = self.read_buf.take() {
dst
} else {
self.pool.get_read_buf()
};
let dst_len = dst.len();
let (hw, lw) = self.pool.read_params().unpack();
loop {
// make sure we've got room
let remaining = dst.remaining_mut();
if remaining < lw {
dst.reserve(hw - remaining);
match frame {
Frame::Binary(bin) => dst.extend_from_slice(&bin),
Frame::Continuation(Item::FirstBinary(bin)) => {
self.insert_flags(Flags::CONTINUATION);
dst.extend_from_slice(&bin);
}
Frame::Continuation(Item::Continue(bin)) => {
self.continuation_must_start("Continuation frame is not started")?;
dst.extend_from_slice(&bin);
}
Frame::Continuation(Item::Last(bin)) => {
self.continuation_must_start(
"Continuation frame is not started, last frame is received",
)?;
dst.extend_from_slice(&bin);
self.remove_flags(Flags::CONTINUATION);
}
Frame::Continuation(Item::FirstText(_)) => {
self.insert_flags(Flags::PROTO_ERR);
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSocket Text continuation frames are not supported",
));
}
Frame::Text(_) => {
self.insert_flags(Flags::PROTO_ERR);
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSockets Text frames are not supported",
));
}
Frame::Ping(msg) => {
let _ = buf.with_write_buf(|b| {
self.codec.encode_vec(Message::Pong(msg), b.get_dst())
});
}
Frame::Pong(_) => (),
Frame::Close(_) => {
buf.want_shutdown();
break;
}
};
}
let frame = if let Some(frame) =
self.codec.decode_vec(&mut src).map_err(|e| {
log::trace!("Failed to decode ws codec frames: {:?}", e);
self.insert_flags(Flags::PROTO_ERR);
io::Error::new(io::ErrorKind::Other, e)
})? {
frame
} else {
break;
};
match frame {
Frame::Binary(bin) => dst.extend_from_slice(&bin),
Frame::Continuation(Item::FirstBinary(bin)) => {
self.insert_flags(Flags::CONTINUATION);
dst.extend_from_slice(&bin);
}
Frame::Continuation(Item::Continue(bin)) => {
self.continuation_must_start("Continuation frame is not started")?;
dst.extend_from_slice(&bin);
}
Frame::Continuation(Item::Last(bin)) => {
self.continuation_must_start(
"Continuation frame is not started, last frame is received",
)?;
dst.extend_from_slice(&bin);
self.remove_flags(Flags::CONTINUATION);
}
Frame::Continuation(Item::FirstText(_)) => {
self.insert_flags(Flags::PROTO_ERR);
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSocket Text continuation frames are not supported",
));
}
Frame::Text(_) => {
self.insert_flags(Flags::PROTO_ERR);
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSockets Text frames are not supported",
));
}
Frame::Ping(msg) => {
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
let _ = self.codec.encode_vec(Message::Pong(msg), &mut b);
self.inner.release_write_buf(b)?;
}
Frame::Pong(_) => (),
Frame::Close(_) => {
io.want_shutdown(None);
break;
}
};
}
let dlen = dst.len();
let nbytes = dlen - dst_len;
if src.is_empty() {
self.pool.release_read_buf(src);
let nb = dst.len() - dst_len;
buf.set_dst(Some(dst));
buf.set_src(Some(src));
Ok(nb)
} else {
self.inner.release_read_buf(src);
Ok(0)
}
self.read_buf.set(Some(dst));
Ok((dlen, nbytes))
}
fn release_write_buf(&self, src: BytesVec) -> Result<(), io::Error> {
let mut buf = if let Some(buf) = self.inner.get_write_buf() {
buf
} else {
self.pool.get_write_buf()
};
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
if let Some(src) = buf.take_src() {
let dst = buf.get_dst();
// make sure we've got room
let (hw, lw) = self.pool.write_params().unpack();
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(cmp::max(hw, buf.len() + 12) - remaining);
// make sure we've got room
let (hw, lw) = self.pool.write_params().unpack();
let remaining = dst.remaining_mut();
if remaining < lw {
dst.reserve(cmp::max(hw, dst.len() + 12) - remaining);
}
// Encoder ws::Codec do not fail
let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst);
}
// Encoder ws::Codec do not fail
let _ = self
.codec
.encode_vec(Message::Binary(src.freeze()), &mut buf);
self.inner.release_write_buf(buf)
Ok(())
}
}
@ -241,22 +182,12 @@ impl WsTransportFactory {
}
impl<F: Filter> FilterFactory<F> for WsTransportFactory {
type Filter = WsTransport<F>;
type Filter = WsTransport;
type Error = io::Error;
type Future = Ready<Io<Self::Filter>, Self::Error>;
type Future = Ready<Io<Layer<Self::Filter, F>>, Self::Error>;
fn create(self, st: Io<F>) -> Self::Future {
let pool = st.memory_pool();
Ready::from(st.map_filter(|inner: F| {
Ok(WsTransport {
pool,
inner,
codec: self.codec,
flags: Cell::new(Flags::empty()),
read_buf: Cell::new(None),
})
}))
fn create(self, io: Io<F>) -> Self::Future {
Ready::Ok(WsTransport::create(io, self.codec))
}
}

View file

@ -494,11 +494,8 @@ async fn test_ws_transport() {
)
.unwrap();
let io = io
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
.await?;
// start websocket service
let io = ws::WsTransport::create(io, ws::Codec::default());
while let Some(item) =
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
{

View file

@ -254,9 +254,7 @@ async fn test_transport() {
)
.unwrap();
let io = io
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
.await?;
let io = ws::WsTransport::create(io, ws::Codec::default());
// start websocket service
while let Some(item) =

View file

@ -95,7 +95,7 @@ async fn test_transport() {
});
// client service
let io = srv.ws().await.unwrap().into_transport().await;
let io = srv.ws().await.unwrap().into_transport();
io.send(Bytes::from_static(b"text"), &BytesCodec)
.await

View file

@ -120,6 +120,7 @@ async fn test_run() {
// stop
let _ = srv.stop(false).await;
thread::sleep(time::Duration::from_millis(100));
assert!(net::TcpStream::connect(addr).is_err());
thread::sleep(time::Duration::from_millis(100));
@ -250,7 +251,6 @@ fn test_configure_async() {
#[cfg(feature = "tokio")]
#[allow(unreachable_code)]
fn test_panic_in_worker() {
env_logger::init();
let counter = Arc::new(AtomicUsize::new(0));
let counter2 = counter.clone();