Add openssl filter (#69)

* add ntex-openssl

* cleanup io api

* add filter shutdown
This commit is contained in:
Nikolay Kim 2021-12-14 22:38:47 +06:00 committed by GitHub
parent 841ad736d4
commit dafd339817
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 1178 additions and 303 deletions

View file

@ -4,6 +4,7 @@ members = [
"ntex-bytes",
"ntex-codec",
"ntex-io",
"ntex-openssl",
"ntex-router",
"ntex-rt",
"ntex-service",
@ -16,6 +17,7 @@ ntex = { path = "ntex" }
ntex-bytes = { path = "ntex-bytes" }
ntex-codec = { path = "ntex-codec" }
ntex-io = { path = "ntex-io" }
ntex-openssl = { path = "ntex-openssl" }
ntex-router = { path = "ntex-router" }
ntex-rt = { path = "ntex-rt" }
ntex-service = { path = "ntex-service" }

View file

@ -37,3 +37,4 @@ tok-io = { version = "1", package = "tokio", default-features = false, features
ntex = "0.4.13"
futures = "0.3.13"
rand = "0.8"
env_logger = "0.9"

View file

@ -1,6 +1,6 @@
//! Framed transport dispatcher
use std::{
cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time,
cell::Cell, future::Future, io, pin::Pin, rc::Rc, task::Context, task::Poll, time,
};
use ntex_bytes::Pool;
@ -70,6 +70,7 @@ enum DispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
Io(io::Error),
}
enum PollService<U: Encoder + Decoder> {
@ -171,10 +172,19 @@ where
{
fn handle_result(&self, item: Result<S::Response, S::Error>, write: WriteRef<'_>) {
self.inflight.set(self.inflight.get() - 1);
match write.encode_result(item, &self.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(err) => self.error.set(Some(err.into())),
match item {
Ok(Some(val)) => match write.encode(val, &self.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.error.set(Some(DispatcherError::Encoder(err)))
}
Err(Either::Right(err)) => {
self.error.set(Some(DispatcherError::Io(err)))
}
},
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
Ok(None) => return,
}
write.wake_dispatcher();
}
@ -217,7 +227,10 @@ where
match slf.st.get() {
DispatcherState::Processing => {
let result = match slf.poll_service(this.service, cx, read) {
Poll::Pending => return Poll::Pending,
Poll::Pending => {
let _ = read.poll_ready(cx);
return Poll::Pending;
}
Poll::Ready(result) => result,
};
@ -237,8 +250,20 @@ where
}
Ok(None) => {
log::trace!("not enough data to decode next frame, register dispatch task");
read.wake(cx);
return Poll::Pending;
// service is ready, wake io read task
match read.poll_ready(cx) {
Poll::Pending
| Poll::Ready(Ok(Some(()))) => {
read.resume();
return Poll::Pending;
}
Poll::Ready(Ok(None)) => {
DispatchItem::Disconnect(None)
}
Poll::Ready(Err(err)) => {
DispatchItem::Disconnect(Some(err))
}
}
}
Err(err) => {
slf.st.set(DispatcherState::Stop);
@ -248,8 +273,18 @@ where
}
} else {
// no new events
state.register_dispatcher(cx);
return Poll::Pending;
match read.poll_ready(cx) {
Poll::Pending | Poll::Ready(Ok(Some(()))) => {
read.resume();
return Poll::Pending;
}
Poll::Ready(Ok(None)) => {
DispatchItem::Disconnect(None)
}
Poll::Ready(Err(err)) => {
DispatchItem::Disconnect(Some(err))
}
}
}
}
PollService::Item(item) => item,
@ -318,7 +353,7 @@ where
if slf.shared.inflight.get() == 0 {
slf.st.set(DispatcherState::Shutdown);
state.shutdown(cx);
state.init_shutdown(cx);
} else {
state.register_dispatcher(cx);
return Poll::Pending;
@ -368,15 +403,19 @@ where
item: Result<Option<<U as Encoder>::Item>, S::Error>,
write: WriteRef<'_>,
) {
match write.encode_result(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.error.set(Some(err));
}
Err(Either::Right(err)) => {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
match item {
Ok(Some(item)) => match write.encode(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
Err(Either::Right(err)) => {
self.shared.error.set(Some(DispatcherError::Io(err)))
}
},
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
Ok(None) => (),
}
}
@ -388,9 +427,6 @@ where
) -> Poll<PollService<U>> {
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
// service is ready, wake io read task
read.resume();
// check keepalive timeout
self.check_keepalive();
@ -407,6 +443,9 @@ where
DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
DispatcherError::Io(err) => {
PollService::Item(DispatchItem::Disconnect(Some(err)))
}
DispatcherError::Service(err) => {
self.error.set(Some(err));
PollService::ServiceError
@ -425,7 +464,7 @@ where
// get io error
if let Some(err) = self.state.take_error() {
PollService::Item(DispatchItem::IoError(err))
PollService::Item(DispatchItem::Disconnect(Some(err)))
} else {
PollService::ServiceError
}
@ -803,15 +842,15 @@ mod tests {
// response message
assert!(!state.write().is_ready());
assert_eq!(state.write().with_buf(|buf| buf.len()), 65536);
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 65536);
client.remote_buffer_cap(10240);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()), 55296);
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 55296);
client.remote_buffer_cap(45056);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()), 10240);
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 10240);
// backpressure disabled
assert!(state.write().is_ready());
@ -821,7 +860,6 @@ mod tests {
#[ntex::test]
async fn test_keepalive() {
let (client, server) = IoTest::create();
// do not allow to write to socket
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
@ -854,8 +892,7 @@ mod tests {
.keepalive_timeout(Seconds(1))
.await;
});
state.0.disconnect_timeout.set(Seconds(1));
state.0.disconnect_timeout.set(Millis::ONE_SEC);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

View file

@ -2,7 +2,7 @@ use std::{io, rc::Rc, task::Context, task::Poll};
use ntex_bytes::BytesMut;
use super::state::{Flags, IoStateInner};
use super::state::{Flags, IoRef, IoStateInner};
use super::{Filter, ReadFilter, WriteFilter, WriteReadiness};
pub struct DefaultFilter(Rc<IoStateInner>);
@ -13,7 +13,20 @@ impl DefaultFilter {
}
}
impl Filter for DefaultFilter {}
impl Filter for DefaultFilter {
#[inline]
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
let mut flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
flags.insert(Flags::IO_SHUTDOWN);
self.0.flags.set(flags);
self.0.read_task.wake();
self.0.write_task.wake();
}
Poll::Ready(Ok(()))
}
}
impl ReadFilter for DefaultFilter {
#[inline]
@ -48,20 +61,20 @@ impl ReadFilter for DefaultFilter {
}
#[inline]
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
if new_bytes > 0 {
if buf.len() > self.0.pool.get().read_params().high as usize {
log::trace!(
"buffer is too large {}, enable read back-pressure",
buf.len()
);
self.0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else {
self.0.insert_flags(Flags::RD_READY);
}
self.0.dispatch_task.wake();
fn release_read_buf(
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
if new_bytes > 0 && buf.len() > self.0.pool.get().read_params().high as usize {
log::trace!(
"buffer is too large {}, enable read back-pressure",
buf.len()
);
self.0.insert_flags(Flags::RD_BUF_FULL);
}
self.0.read_buf.set(Some(buf));
Ok(())
}
}
@ -71,12 +84,23 @@ impl WriteFilter for DefaultFilter {
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
let flags = self.0.flags.get();
let mut flags = self.0.flags.get();
if flags.contains(Flags::IO_ERR) {
Poll::Ready(Err(WriteReadiness::Terminate))
} else if flags.intersects(Flags::IO_SHUTDOWN) {
Poll::Ready(Err(WriteReadiness::Shutdown))
Poll::Ready(Err(WriteReadiness::Shutdown(
self.0.disconnect_timeout.get(),
)))
} else if flags.contains(Flags::IO_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TO)
{
flags.insert(Flags::IO_FILTERS_TO);
self.0.flags.set(flags);
self.0.write_task.register(cx.waker());
Poll::Ready(Err(WriteReadiness::Timeout(
self.0.disconnect_timeout.get(),
)))
} else {
self.0.write_task.register(cx.waker());
Poll::Ready(Ok(()))
@ -100,13 +124,15 @@ impl WriteFilter for DefaultFilter {
}
#[inline]
fn release_write_buf(&self, buf: BytesMut) {
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
let pool = self.0.pool.get();
if buf.is_empty() {
pool.release_write_buf(buf);
} else {
self.0.write_buf.set(Some(buf));
self.0.write_task.wake();
}
Ok(())
}
}
@ -120,7 +146,11 @@ impl NullFilter {
}
}
impl Filter for NullFilter {}
impl Filter for NullFilter {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}
impl ReadFilter for NullFilter {
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
@ -133,7 +163,9 @@ impl ReadFilter for NullFilter {
None
}
fn release_read_buf(&self, _: BytesMut, _: usize) {}
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> {
Ok(())
}
}
impl WriteFilter for NullFilter {
@ -147,5 +179,7 @@ impl WriteFilter for NullFilter {
None
}
fn release_write_buf(&self, _: BytesMut) {}
fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> {
Ok(())
}
}

View file

@ -14,19 +14,22 @@ mod tokio_impl;
use ntex_bytes::BytesMut;
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Millis;
pub use self::dispatcher::Dispatcher;
pub use self::filter::DefaultFilter;
pub use self::state::{Io, IoRef, ReadRef, WriteRef};
pub use self::tasks::{ReadState, WriteState};
pub use self::time::Timer;
pub use self::utils::{from_iostream, into_boxed};
pub use self::utils::{filter_factory, from_iostream, into_boxed, into_io};
pub type IoBoxed = Io<Box<dyn Filter>>;
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum WriteReadiness {
Shutdown,
Timeout(Millis),
Shutdown(Millis),
Terminate,
}
@ -37,7 +40,8 @@ pub trait ReadFilter {
fn get_read_buf(&self) -> Option<BytesMut>;
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize);
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize)
-> Result<(), io::Error>;
}
pub trait WriteFilter {
@ -48,10 +52,12 @@ pub trait WriteFilter {
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_write_buf(&self, buf: BytesMut);
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
}
pub trait Filter: ReadFilter + WriteFilter {}
pub trait Filter: ReadFilter + WriteFilter {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
}
pub trait FilterFactory<F: Filter>: Sized {
type Filter: Filter;
@ -59,7 +65,7 @@ pub trait FilterFactory<F: Filter>: Sized {
type Error: fmt::Debug;
type Future: Future<Output = Result<Io<Self::Filter>, Self::Error>>;
fn create(&self, st: Io<F>) -> Self::Future;
fn create(self, st: Io<F>) -> Self::Future;
}
pub trait IoStream {
@ -79,8 +85,8 @@ pub enum DispatchItem<U: Encoder + Decoder> {
DecoderError(<U as Decoder>::Error),
/// Encoder parse error
EncoderError(<U as Encoder>::Error),
/// Unexpected io error
IoError(io::Error),
/// Socket is disconnected
Disconnect(Option<io::Error>),
}
impl<U> fmt::Debug for DispatchItem<U>
@ -108,8 +114,8 @@ where
DispatchItem::DecoderError(ref e) => {
write!(fmt, "DispatchItem::DecoderError({:?})", e)
}
DispatchItem::IoError(ref e) => {
write!(fmt, "DispatchItem::IoError({:?})", e)
DispatchItem::Disconnect(ref e) => {
write!(fmt, "DispatchItem::Disconnect({:?})", e)
}
}
}
@ -128,8 +134,8 @@ mod tests {
assert!(format!("{:?}", err).contains("DispatchItem::Encoder"));
let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::Decoder"));
let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::IoError"));
let err = T::Disconnect(Some(io::Error::new(io::ErrorKind::Other, "err")));
assert!(format!("{:?}", err).contains("DispatchItem::Disconnect"));
assert!(format!("{:?}", T::WBackPressureEnabled)
.contains("DispatchItem::WBackPressureEnabled"));

View file

@ -4,7 +4,8 @@ use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
use ntex_bytes::{BytesMut, PoolId, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Seconds};
use ntex_util::time::{Millis, Seconds};
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
use super::filter::{DefaultFilter, NullFilter};
use super::tasks::{ReadState, WriteState};
@ -14,8 +15,12 @@ bitflags::bitflags! {
pub struct Flags: u16 {
/// io error occured
const IO_ERR = 0b0000_0000_0000_0001;
/// shuting down filters
const IO_FILTERS = 0b0000_0000_0000_0010;
/// shuting down filters timeout
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
/// shutdown io tasks
const IO_SHUTDOWN = 0b0000_0000_0000_0100;
const IO_SHUTDOWN = 0b0000_0000_0000_1000;
/// pause io read
const RD_PAUSED = 0b0000_0000_0000_1000;
@ -51,7 +56,7 @@ pub struct IoRef(pub(super) Rc<IoStateInner>);
pub(crate) struct IoStateInner {
pub(super) flags: Cell<Flags>,
pub(super) pool: Cell<PoolRef>,
pub(super) disconnect_timeout: Cell<Seconds>,
pub(super) disconnect_timeout: Cell<Millis>,
pub(super) error: Cell<Option<io::Error>>,
pub(super) read_task: LocalWaker,
pub(super) write_task: LocalWaker,
@ -77,6 +82,16 @@ impl IoStateInner {
self.flags.set(flags);
}
#[inline]
pub(super) fn notify_keepalive(&self) {
let mut flags = self.flags.get();
if !flags.contains(Flags::DSP_KEEPALIVE) {
flags.insert(Flags::DSP_KEEPALIVE);
self.flags.set(flags);
self.dispatch_task.wake();
}
}
#[inline]
pub(super) fn notify_disconnect(&self) {
let mut on_disconnect = self.on_disconnect.borrow_mut();
@ -86,6 +101,36 @@ impl IoStateInner {
}
}
}
#[inline]
fn is_io_err(&self) -> bool {
self.flags.get().contains(Flags::IO_ERR)
}
#[inline]
pub(super) fn shutdown_filters(&self, st: &IoRef) -> Result<(), io::Error> {
let mut flags = self.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let result = match self.filter.get().shutdown(st) {
Poll::Pending => return Ok(()),
Poll::Ready(Ok(())) => {
flags.insert(Flags::IO_SHUTDOWN);
Ok(())
}
Poll::Ready(Err(err)) => {
flags.insert(Flags::IO_ERR);
self.dispatch_task.wake();
Err(err)
}
};
self.flags.set(flags);
self.read_task.wake();
self.write_task.wake();
result
} else {
Ok(())
}
}
}
impl Eq for IoStateInner {}
@ -130,7 +175,7 @@ impl Io {
pool: Cell::new(pool),
flags: Cell::new(Flags::empty()),
error: Cell::new(None),
disconnect_timeout: Cell::new(Seconds(1)),
disconnect_timeout: Cell::new(Millis::ONE_SEC),
dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(),
write_task: LocalWaker::new(),
@ -147,10 +192,12 @@ impl Io {
};
inner.filter.replace(filter_ref);
// start io tasks
io.start(ReadState(inner.clone()), WriteState(inner.clone()));
let io_ref = IoRef(inner);
Io(IoRef(inner), FilterItem::Ptr(Box::into_raw(filter)))
// start io tasks
io.start(ReadState(io_ref.clone()), WriteState(io_ref.clone()));
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
}
}
@ -172,7 +219,7 @@ impl<F> Io<F> {
#[inline]
/// Set io disconnect timeout in secs
pub fn set_disconnect_timeout(&self, timeout: Seconds) {
self.0 .0.disconnect_timeout.set(timeout);
self.0 .0.disconnect_timeout.set(timeout.into());
}
}
@ -201,6 +248,25 @@ impl<F> Io<F> {
pub fn dispatcher_stopped(&self) {
self.0 .0.insert_flags(Flags::DSP_STOP);
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub fn init_shutdown(&self, cx: &mut Context<'_>) {
let flags = self.0 .0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
self.0 .0.insert_flags(Flags::IO_FILTERS);
if let Err(err) = self.0 .0.shutdown_filters(&self.0) {
self.0 .0.error.set(Some(err));
self.0 .0.insert_flags(Flags::IO_ERR);
}
self.0 .0.read_task.wake();
self.0 .0.write_task.wake();
self.0 .0.dispatch_task.register(cx.waker());
}
}
}
impl IoRef {
@ -220,7 +286,7 @@ impl IoRef {
#[inline]
/// Check if io error occured in read or write task
pub fn is_io_err(&self) -> bool {
self.0.flags.get().contains(Flags::IO_ERR)
self.0.is_io_err()
}
#[inline]
@ -244,20 +310,6 @@ impl IoRef {
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP)
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub fn shutdown(&self, cx: &mut Context<'_>) {
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
log::trace!("initiate io shutdown {:?}", flags);
self.0.insert_flags(Flags::IO_SHUTDOWN);
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.register(cx.waker());
}
}
#[inline]
/// Take io error if any occured
pub fn take_error(&self) -> Option<io::Error> {
@ -331,32 +383,21 @@ impl<F> Io<F> {
};
self.0 .0.read_buf.set(buf);
let result = match item {
return match item {
Ok(Some(el)) => Ok(Some(el)),
Ok(None) => {
self.0 .0.remove_flags(Flags::RD_READY);
poll_fn(|cx| {
if read.is_ready() {
Poll::Ready(())
} else {
read.wake(cx);
Poll::Pending
}
})
.await;
if self.is_io_err() {
if let Some(err) = self.take_error() {
Err(Either::Right(err))
} else {
Ok(None)
}
} else {
continue;
if poll_fn(|cx| read.poll_ready(cx))
.await
.map_err(Either::Right)?
.is_none()
{
return Ok(None);
}
continue;
}
Err(err) => Err(Either::Left(err)),
};
return result;
}
}
@ -364,8 +405,8 @@ impl<F> Io<F> {
/// Encode item, send to a peer
pub async fn send<U>(
&self,
codec: &U,
item: U::Item,
codec: &U,
) -> Result<(), Either<U::Error, io::Error>>
where
U: Encoder,
@ -374,31 +415,46 @@ impl<F> Io<F> {
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
let is_write_sleep = buf.is_empty();
codec.encode(item, &mut buf).map_err(Either::Left)?;
filter.release_write_buf(buf);
filter.release_write_buf(buf).map_err(Either::Right)?;
self.0 .0.insert_flags(Flags::WR_WAIT);
if is_write_sleep {
self.0 .0.write_task.wake();
}
poll_fn(|cx| {
if !self.0 .0.flags.get().contains(Flags::WR_WAIT) || self.is_io_err() {
Poll::Ready(())
} else {
self.register_dispatcher(cx);
Poll::Pending
}
})
.await;
poll_fn(|cx| self.write().poll_flush(cx))
.await
.map_err(Either::Right)?;
Ok(())
}
if self.is_io_err() {
let err = self.0 .0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Internal error")
});
Err(Either::Right(err))
} else {
#[inline]
/// Shuts down connection
pub async fn shutdown(&self) -> Result<(), io::Error> {
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
Ok(())
} else {
poll_fn(|cx| {
let flags = self.flags();
if !flags.contains(Flags::IO_FILTERS) {
self.init_shutdown(cx);
}
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
if let Some(err) = self.0 .0.error.take() {
Poll::Ready(Err(err))
} else {
Poll::Ready(Ok(()))
}
} else {
self.0 .0.insert_flags(Flags::IO_FILTERS);
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
}
})
.await
}
}
@ -412,26 +468,52 @@ impl<F> Io<F> {
where
U: Decoder,
{
let mut buf = self.0 .0.read_buf.take();
let item = if let Some(ref mut buf) = buf {
codec.decode(buf)
} else {
Ok(None)
};
self.0 .0.read_buf.set(buf);
if self
.read()
.poll_ready(cx)
.map_err(Either::Right)?
.is_ready()
{
let mut buf = self.0 .0.read_buf.take();
let item = if let Some(ref mut buf) = buf {
codec.decode(buf)
} else {
Ok(None)
};
self.0 .0.read_buf.set(buf);
match item {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => {
self.read().wake(cx);
Poll::Pending
match item {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => {
if let Poll::Ready(res) =
self.read().poll_ready(cx).map_err(Either::Right)?
{
if res.is_none() {
return Poll::Ready(Ok(None));
}
}
Poll::Pending
}
Err(err) => Poll::Ready(Err(Either::Left(err))),
}
Err(err) => Poll::Ready(Err(Either::Left(err))),
} else {
Poll::Pending
}
}
}
impl<F: Filter> Io<F> {
#[inline]
/// Get referece to filter
pub fn filter(&self) -> &F {
if let FilterItem::Ptr(p) = self.1 {
if let Some(r) = unsafe { p.as_ref() } {
return r;
}
}
panic!()
}
#[inline]
pub fn into_boxed(mut self) -> crate::IoBoxed
where
@ -457,18 +539,18 @@ impl<F: Filter> Io<F> {
}
#[inline]
pub async fn add_filter<T>(self, factory: &T) -> Result<Io<T::Filter>, T::Error>
pub fn add_filter<T>(self, factory: T) -> T::Future
where
T: FilterFactory<F>,
{
factory.create(self).await
factory.create(self)
}
#[inline]
pub fn map_filter<T, U>(mut self, map: T) -> Io<U>
pub fn map_filter<T, U>(mut self, map: U) -> Result<Io<T::Filter>, T::Error>
where
T: FnOnce(F) -> U,
U: Filter,
T: FilterFactory<F>,
U: FnOnce(F) -> Result<T::Filter, T::Error>,
{
// replace current filter
let filter = unsafe {
@ -477,7 +559,7 @@ impl<F: Filter> Io<F> {
FilterItem::Boxed(_) => panic!(),
FilterItem::Ptr(p) => {
assert!(!p.is_null());
Box::new(map(*Box::from_raw(p)))
Box::new(map(*Box::from_raw(p))?)
}
};
let filter_ref: &'static dyn Filter = {
@ -488,7 +570,7 @@ impl<F: Filter> Io<F> {
filter
};
Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter)))
Ok(Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter))))
}
}
@ -562,7 +644,7 @@ impl<'a> WriteRef<'a> {
#[inline]
/// Get mut access to write buffer
pub fn with_buf<F, R>(&self, f: F) -> R
pub fn with_buf<F, R>(&self, f: F) -> Result<R, io::Error>
where
F: FnOnce(&mut BytesMut) -> R,
{
@ -575,8 +657,8 @@ impl<'a> WriteRef<'a> {
}
let result = f(&mut buf);
filter.release_write_buf(buf);
result
filter.release_write_buf(buf)?;
Ok(result)
}
#[inline]
@ -587,7 +669,7 @@ impl<'a> WriteRef<'a> {
&self,
item: U::Item,
codec: &U,
) -> Result<bool, <U as Encoder>::Error>
) -> Result<bool, Either<<U as Encoder>::Error, io::Error>>
where
U: Encoder,
{
@ -608,70 +690,45 @@ impl<'a> WriteRef<'a> {
}
// encode item and wake write task
let result = codec.encode(item, &mut buf).map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
buf.len() < hw
});
filter.release_write_buf(buf);
result
let result = codec
.encode(item, &mut buf)
.map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
buf.len() < hw
})
.map_err(Either::Left);
filter.release_write_buf(buf).map_err(Either::Right)?;
Ok(result?)
} else {
Ok(true)
}
}
#[inline]
/// Write item to a buf and wake up io task
pub fn encode_result<U, E>(
&self,
item: Result<Option<U::Item>, E>,
codec: &U,
) -> Result<bool, Either<E, U::Error>>
where
U: Encoder,
{
let flags = self.0.flags.get();
/// Wake write task and instruct to write all data.
///
/// When write task is done wake dispatcher.
pub fn poll_flush(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.0.insert_flags(Flags::WR_WAIT);
if !flags.intersects(Flags::IO_ERR | Flags::DSP_ERR) {
match item {
Ok(Some(item)) => {
let filter = self.0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
let is_write_sleep = buf.is_empty();
let (hw, lw) = self.0.pool.get().write_params().unpack();
// make sure we've got room
let remaining = buf.capacity() - buf.len();
if remaining < lw {
buf.reserve(hw - remaining);
}
// encode item
if let Err(err) = codec.encode(item, &mut buf) {
log::trace!("Encoder error: {:?}", err);
filter.release_write_buf(buf);
self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR);
self.0.dispatch_task.wake();
return Err(Either::Right(err));
} else if is_write_sleep {
self.0.write_task.wake();
}
let result = Ok(buf.len() < hw);
filter.release_write_buf(buf);
result
}
Err(err) => {
self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR);
self.0.dispatch_task.wake();
Err(Either::Left(err))
}
_ => Ok(true),
if let Some(buf) = self.0.write_buf.take() {
if !buf.is_empty() {
self.0.write_buf.set(Some(buf));
self.0.write_task.wake();
self.0.dispatch_task.register(cx.waker());
return Poll::Pending;
}
}
if self.0.is_io_err() {
Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "disconnected")
})))
} else {
Ok(true)
self.0.dispatch_task.register(cx.waker());
Poll::Ready(Ok(()))
}
}
}
@ -720,28 +777,6 @@ impl<'a> ReadRef<'a> {
}
}
#[inline]
/// Wake read task and instruct to read more data
///
/// Only wakes if back-pressure is enabled on read task
/// otherwise read is already awake.
pub fn wake(&self, cx: &mut Context<'_>) {
let mut flags = self.0.flags.get();
flags.remove(Flags::RD_READY);
if flags.contains(Flags::RD_BUF_FULL) {
log::trace!("read back-pressure is enabled, wake io task");
flags.remove(Flags::RD_BUF_FULL);
self.0.read_task.wake();
}
if flags.contains(Flags::RD_PAUSED) {
log::trace!("read is paused, wake io task");
flags.remove(Flags::RD_PAUSED);
self.0.read_task.wake();
}
self.0.flags.set(flags);
self.0.dispatch_task.register(cx.waker());
}
#[inline]
/// Attempts to decode a frame from the read buffer.
pub fn decode<U>(
@ -753,7 +788,11 @@ impl<'a> ReadRef<'a> {
{
let mut buf = self.0.read_buf.take();
let result = if let Some(ref mut buf) = buf {
codec.decode(buf)
let result = codec.decode(buf);
if result.as_ref().map(|v| v.is_none()).unwrap_or(false) {
self.0.remove_flags(Flags::RD_READY);
}
result
} else {
self.0.remove_flags(Flags::RD_READY);
Ok(None)
@ -775,12 +814,46 @@ impl<'a> ReadRef<'a> {
.unwrap_or_else(|| self.0.pool.get().get_read_buf());
let res = f(&mut buf);
if buf.is_empty() {
self.0.remove_flags(Flags::RD_READY);
self.0.pool.get().release_read_buf(buf);
} else {
self.0.read_buf.set(Some(buf));
}
res
}
#[inline]
/// Wake read task and instruct to read more data
///
/// Only wakes if back-pressure is enabled on read task
/// otherwise read is already awake.
pub fn poll_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<Option<()>, io::Error>> {
let mut flags = self.0.flags.get();
let ready = flags.contains(Flags::RD_READY);
if self.0.is_io_err() {
if let Some(err) = self.0.error.take() {
Poll::Ready(Err(err))
} else {
Poll::Ready(Ok(None))
}
} else if ready {
Poll::Ready(Ok(Some(())))
} else {
flags.remove(Flags::RD_READY);
if flags.contains(Flags::RD_BUF_FULL) {
log::trace!("read back-pressure is enabled, wake io task");
flags.remove(Flags::RD_BUF_FULL);
self.0.read_task.wake();
}
self.0.flags.set(flags);
self.0.dispatch_task.register(cx.waker());
Poll::Pending
}
}
}
/// OnDisconnect future resolves when socket get disconnected
@ -866,6 +939,7 @@ mod tests {
#[ntex::test]
async fn utils() {
env_logger::init();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write(TEXT);
@ -907,14 +981,14 @@ mod tests {
client.remote_buffer_cap(1024);
let state = Io::new(server);
state
.send(&BytesCodec, Bytes::from_static(b"test"))
.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await;
let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
assert!(res.is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::DSP_STOP));
@ -967,7 +1041,11 @@ mod tests {
in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>,
}
impl<F: ReadFilter + WriteFilter> Filter for Counter<F> {}
impl<F: ReadFilter + WriteFilter> Filter for Counter<F> {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}
impl<F: ReadFilter> ReadFilter for Counter<F> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
@ -982,9 +1060,13 @@ mod tests {
self.inner.get_read_buf()
}
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
fn release_read_buf(
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
self.in_bytes.set(self.in_bytes.get() + new_bytes);
self.inner.release_read_buf(buf, new_bytes);
self.inner.release_read_buf(buf, new_bytes)
}
}
@ -1009,9 +1091,9 @@ mod tests {
}
}
fn release_write_buf(&self, buf: BytesMut) {
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
self.out_bytes.set(self.out_bytes.get() + buf.len());
self.inner.release_write_buf(buf);
self.inner.release_write_buf(buf)
}
}
@ -1023,14 +1105,19 @@ mod tests {
type Error = ();
type Future = Ready<Io<Counter<F>>, Self::Error>;
fn create(&self, st: Io<F>) -> Self::Future {
fn create(self, io: Io<F>) -> Self::Future {
let in_bytes = self.0.clone();
let out_bytes = self.1.clone();
Ready::Ok(st.map_filter(|inner| Counter {
inner,
in_bytes,
out_bytes,
}))
Ready::Ok(
io.map_filter::<CounterFactory, _>(|inner| {
Ok(Counter {
inner,
in_bytes,
out_bytes,
})
})
.unwrap(),
)
}
}
@ -1041,7 +1128,7 @@ mod tests {
let factory = CounterFactory(in_bytes.clone(), out_bytes.clone());
let (client, server) = IoTest::create();
let state = Io::new(server).add_filter(&factory).await.unwrap();
let state = Io::new(server).add_filter(factory).await.unwrap();
client.remote_buffer_cap(1024);
client.write(TEXT);
@ -1049,7 +1136,7 @@ mod tests {
assert_eq!(msg, Bytes::from_static(BIN));
state
.send(&BytesCodec, Bytes::from_static(b"test"))
.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
let buf = client.read().await.unwrap();
@ -1066,10 +1153,10 @@ mod tests {
let (client, server) = IoTest::create();
let state = Io::new(server)
.add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone()))
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.await
.unwrap()
.add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone()))
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.await
.unwrap();
let state = state.into_boxed();
@ -1080,7 +1167,7 @@ mod tests {
assert_eq!(msg, Bytes::from_static(BIN));
state
.send(&BytesCodec, Bytes::from_static(b"test"))
.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
let buf = client.read().await.unwrap();

View file

@ -1,98 +1,115 @@
use std::{io, rc::Rc, task::Context, task::Poll};
use std::{io, task::Context, task::Poll};
use ntex_bytes::{BytesMut, PoolRef};
use ntex_util::time::Seconds;
use super::{state::Flags, state::IoStateInner, WriteReadiness};
use super::{state::Flags, IoRef, WriteReadiness};
pub struct ReadState(pub(super) Rc<IoStateInner>);
pub struct ReadState(pub(super) IoRef);
impl ReadState {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0.pool.get()
self.0 .0.pool.get()
}
#[inline]
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.0.filter.get().poll_read_ready(cx)
self.0 .0.filter.get().poll_read_ready(cx)
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0.filter.get().read_closed(err);
self.0 .0.filter.get().read_closed(err);
}
#[inline]
pub fn get_read_buf(&self) -> BytesMut {
self.0
.0
.filter
.get()
.get_read_buf()
.unwrap_or_else(|| self.0.pool.get().get_read_buf())
.unwrap_or_else(|| self.0 .0.pool.get().get_read_buf())
}
#[inline]
pub fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
pub fn release_read_buf(
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
if buf.is_empty() {
self.0.pool.get().release_read_buf(buf);
self.0 .0.pool.get().release_read_buf(buf);
Ok(())
} else {
self.0.filter.get().release_read_buf(buf, new_bytes);
let mut flags = self.0 .0.flags.get();
// notify dispatcher
if new_bytes > 0 {
flags.insert(Flags::RD_READY);
self.0 .0.flags.set(flags);
self.0 .0.dispatch_task.wake();
}
self.0 .0.filter.get().release_read_buf(buf, new_bytes)?;
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters(&self.0)?;
}
Ok(())
}
}
}
pub struct WriteState(pub(super) Rc<IoStateInner>);
pub struct WriteState(pub(super) IoRef);
impl WriteState {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0.pool.get()
}
#[inline]
pub fn disconnect_timeout(&self) -> Seconds {
self.0.disconnect_timeout.get()
self.0 .0.pool.get()
}
#[inline]
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
self.0.filter.get().poll_write_ready(cx)
self.0 .0.filter.get().poll_write_ready(cx)
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0.filter.get().write_closed(err)
self.0 .0.filter.get().write_closed(err)
}
#[inline]
pub fn get_write_buf(&self) -> Option<BytesMut> {
self.0.write_buf.take()
self.0 .0.write_buf.take()
}
#[inline]
pub fn release_write_buf(&self, buf: BytesMut) {
let pool = self.0.pool.get();
pub fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
let pool = self.0 .0.pool.get();
let mut flags = self.0 .0.flags.get();
if buf.is_empty() {
pool.release_write_buf(buf);
let mut flags = self.0.flags.get();
if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) {
flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
self.0.flags.set(flags);
self.0.dispatch_task.wake();
self.0 .0.flags.set(flags);
self.0 .0.dispatch_task.wake();
}
} else {
// if write buffer is smaller than high watermark value, turn off back-pressure
if buf.len() < pool.write_params_high() << 1 {
let mut flags = self.0.flags.get();
if flags.contains(Flags::WR_BACKPRESSURE) {
flags.remove(Flags::WR_BACKPRESSURE);
self.0.flags.set(flags);
self.0.dispatch_task.wake();
}
if buf.len() < pool.write_params_high() << 1
&& flags.contains(Flags::WR_BACKPRESSURE)
{
flags.remove(Flags::WR_BACKPRESSURE);
self.0 .0.flags.set(flags);
self.0 .0.dispatch_task.wake();
}
self.0.write_buf.set(Some(buf))
self.0 .0.write_buf.set(Some(buf))
}
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters(&self.0)?;
}
Ok(())
}
}

View file

@ -5,7 +5,7 @@ use std::{
use ntex_util::spawn;
use ntex_util::time::{now, sleep, Millis};
use super::state::{Flags, IoRef, IoStateInner};
use super::state::{IoRef, IoStateInner};
pub struct Timer(Rc<RefCell<Inner>>);
@ -79,8 +79,7 @@ impl Timer {
let key = *key;
if key <= now_time {
for st in i.notifications.remove(&key).unwrap() {
st.dispatch_task.wake();
st.insert_flags(Flags::DSP_KEEPALIVE);
st.notify_keepalive();
}
} else {
break;

View file

@ -69,8 +69,13 @@ where
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
this.state.release_read_buf(buf, new_bytes);
this.state.close(None);
if let Err(e) =
this.state.release_read_buf(buf, new_bytes)
{
this.state.close(Some(e));
} else {
this.state.close(None);
}
return Poll::Ready(());
} else {
new_bytes += n;
@ -81,15 +86,19 @@ where
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
}
}
this.state.release_read_buf(buf, new_bytes);
Poll::Pending
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else {
Poll::Pending
}
}
Poll::Pending => Poll::Pending,
}
@ -98,8 +107,8 @@ where
#[derive(Debug)]
enum IoWriteState {
Processing,
Shutdown(Option<Sleep>, Shutdown),
Processing(Option<Sleep>),
Shutdown(Sleep, Shutdown),
}
#[derive(Debug)]
@ -125,7 +134,7 @@ where
Self {
io,
state,
st: IoWriteState::Processing,
st: IoWriteState::Processing(None),
}
}
}
@ -140,22 +149,41 @@ where
let mut this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing => {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return Poll::Ready(());
}
}
// flush framed instance
match flush_io(&mut *this.io.borrow_mut(), &this.state, cx) {
Poll::Pending | Poll::Ready(true) => Poll::Pending,
Poll::Ready(false) => Poll::Ready(()),
}
}
Poll::Ready(Err(WriteReadiness::Shutdown)) => {
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
log::trace!("write task is instructed to shutdown");
this.st = IoWriteState::Shutdown(
this.state.disconnect_timeout().map(sleep),
Shutdown::None,
);
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
@ -229,10 +257,8 @@ where
}
// disconnect timeout
if let Some(ref delay) = delay {
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
log::trace!("write task is stopped after delay");
this.state.close(None);
@ -290,11 +316,17 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
// remove written data
let result = if written == len {
buf.clear();
state.release_write_buf(buf);
if let Err(e) = state.release_write_buf(buf) {
state.close(Some(e));
return Poll::Ready(false);
}
Poll::Ready(true)
} else {
buf.advance(written);
state.release_write_buf(buf);
if let Err(e) = state.release_write_buf(buf) {
state.close(Some(e));
return Poll::Ready(false);
}
Poll::Pending
};

View file

@ -1,6 +1,9 @@
use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory};
use std::{io, marker::PhantomData, task::Context, task::Poll};
use super::{Filter, Io, IoBoxed, IoStream};
use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory};
use ntex_util::future::Ready;
use super::{Filter, FilterFactory, Io, IoBoxed, IoStream};
/// Service that converts any Io<F> stream to IoBoxed stream
pub fn into_boxed<F, S>(
@ -47,3 +50,81 @@ where
}
})
}
/// Service that converts IoStream stream to Io stream
pub fn into_io<I>() -> impl ServiceFactory<
Config = (),
Request = I,
Response = Io,
Error = io::Error,
InitError = (),
>
where
I: IoStream,
{
fn_factory_with_config(move |_: ()| {
Ready::Ok(into_service(move |io| Ready::Ok(Io::new(io))))
})
}
/// Create filter factory service
pub fn filter_factory<T, F>(filter: T) -> FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
F: Filter,
{
FilterServiceFactory {
filter,
_t: PhantomData,
}
}
pub struct FilterServiceFactory<T, F> {
filter: T,
_t: PhantomData<F>,
}
impl<T, F> ServiceFactory for FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
F: Filter,
{
type Config = ();
type Request = Io<F>;
type Response = Io<T::Filter>;
type Error = T::Error;
type Service = FilterService<T, F>;
type InitError = ();
type Future = Ready<Self::Service, Self::InitError>;
fn new_service(&self, _: ()) -> Self::Future {
Ready::Ok(FilterService {
filter: self.filter.clone(),
_t: PhantomData,
})
}
}
pub struct FilterService<T, F> {
filter: T,
_t: PhantomData<F>,
}
impl<T, F> Service for FilterService<T, F>
where
T: FilterFactory<F> + Clone,
F: Filter,
{
type Request = Io<F>;
type Response = Io<T::Filter>;
type Error = T::Error;
type Future = T::Future;
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&self, req: Io<F>) -> Self::Future {
req.add_filter(self.filter.clone())
}
}

27
ntex-openssl/Cargo.toml Normal file
View file

@ -0,0 +1,27 @@
[package]
name = "ntex-openssl"
version = "0.1.0"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
homepage = "https://ntex.rs"
repository = "https://github.com/ntex-rs/ntex.git"
documentation = "https://docs.rs/ntex-openssl/"
categories = ["network-programming", "asynchronous"]
license = "MIT"
edition = "2018"
[lib]
name = "ntex_openssl"
path = "src/lib.rs"
[dependencies]
ntex-bytes = "0.1.7"
ntex-io = "0.1.0"
ntex-util = "0.1.2"
openssl = "0.10.32"
[dev-dependencies]
ntex = { version = "0.4.14", features = ["openssl"] }
futures = "0.3"
env_logger = "0.9"

1
ntex-openssl/LICENSE Symbolic link
View file

@ -0,0 +1 @@
../LICENSE

View file

@ -0,0 +1,16 @@
-----BEGIN CERTIFICATE-----
MIICljCCAX4CCQDztMNlxk6oeTANBgkqhkiG9w0BAQsFADANMQswCQYDVQQIDAJj
YTAeFw0xOTAzMDcwNzEyNThaFw0yMDAzMDYwNzEyNThaMA0xCzAJBgNVBAgMAmNh
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0GMP3YzDVFWgNhRiHnfe
d192131Zi23p8WiutneD9I5WO42c79fOXsxLWn+2HSqPvCPHIBLoMX8o9lgCxt2P
/JUCAWbrE2EuvhkMrWk6/q7xB211XZYfnkqdt7mA0jMUC5o32AX3ew456TAq5P8Y
dq9H/qXdRtAvKD0QdkFfq8ePCiqOhcqacZ/NWva7R4HdgTnbL1DRQjGBXszI07P9
1yw8GOym46uxNHRujQp3lYEhc1V3JTF9kETpSBHyEAkQ8WHxGf8UBHDhh7hcc+KI
JHMlVYy5wDv4ZJeYsY1rD6/n4tyd3r0yzBM57UGf6qrVZEYmLB7Jad+8Df5vIoGh
WwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQB1DEu9NiShCfQuA17MG5O0Jr2/PS1z
/+HW7oW15WXpqDKOEJalid31/Bzwvwq0bE12xKE4ZLdbqJHmJTdSUoGfOfBZKka6
R2thOjqH7hFvxjfgS7kBy5BrRZewM9xKIJ6zU6+6mxR64x9vmkOmppV0fx5clZjH
c7qn5kSNWTMsFbjPnb5BeJJwZdqpMLs99jgoMvGtCUmkyVYODGhh65g6tR9kIPvM
zu/Cw122/y7tFfkuknMSYwGEYF3XcZpXt54a6Lu5hk6PuOTsK+7lC+HX7CSF1dpv
u1szL5fDgiCBFCnyKeOqF61mxTCUht3U++37VDFvhzN1t6HIVTYm2JJ7
-----END CERTIFICATE-----

View file

@ -0,0 +1,35 @@
use std::io;
use ntex::{codec, connect, util::Bytes, util::Either};
use openssl::ssl::{self, SslMethod, SslVerifyMode};
#[ntex::main]
async fn main() -> io::Result<()> {
std::env::set_var("RUST_LOG", "trace");
env_logger::init();
println!("Connecting to openssl server: 127.0.0.1:8443");
// load ssl keys
let mut builder = ssl::SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_verify(SslVerifyMode::NONE);
let connector = builder.build();
// start server
let connector = connect::openssl::IoConnector::new(connector);
let io = connector.connect("127.0.0.1:8443").await.unwrap();
println!("Connected to ssl server");
let result = io
.send(Bytes::from_static(b"hello"), &codec::BytesCodec)
.await
.map_err(Either::into_inner)?;
let resp = io
.next(&codec::BytesCodec)
.await
.map_err(Either::into_inner)?;
println!("disconnecting");
io.shutdown().await
}

View file

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDQYw/djMNUVaA2
FGIed953X3bXfVmLbenxaK62d4P0jlY7jZzv185ezEtaf7YdKo+8I8cgEugxfyj2
WALG3Y/8lQIBZusTYS6+GQytaTr+rvEHbXVdlh+eSp23uYDSMxQLmjfYBfd7Djnp
MCrk/xh2r0f+pd1G0C8oPRB2QV+rx48KKo6Fyppxn81a9rtHgd2BOdsvUNFCMYFe
zMjTs/3XLDwY7Kbjq7E0dG6NCneVgSFzVXclMX2QROlIEfIQCRDxYfEZ/xQEcOGH
uFxz4ogkcyVVjLnAO/hkl5ixjWsPr+fi3J3evTLMEzntQZ/qqtVkRiYsHslp37wN
/m8igaFbAgMBAAECggEAJI278rkGany6pcHdlEqik34DcrliQ7r8FoSuYQOF+hgd
uESXCttoL+jWLwHICEW3AOGlxFKMuGH95Xh6xDeJUl0xBN3wzm11rZLnTmPvHU3C
qfLha5Ex6qpcECZSGo0rLv3WXeZuCv/r2KPCYnj86ZTFpD2kGw/Ztc1AXf4Jsi/1
478Mf23QmAvCAPimGCyjLQx2c9/vg/6K7WnDevY4tDuDKLeSJxKZBSHUn3cM1Bwj
2QzaHfSFA5XljOF5PLeR3cY5ncrrVLWChT9XuGt9YMdLAcSQxgE6kWV1RSCq+lbj
e6OOe879IrrqwBvMQfKQqnm1kl8OrfPMT5CNWKvEgQKBgQD8q5E4x9taDS9RmhRO
07ptsr/I795tX8CaJd/jc4xGuCGBqpNw/hVebyNNYQvpiYzDNBSEhtd59957VyET
hcrGyxD0ByKm8F/lPgFw5y6wi3RUnucCV/jxkMHmxVzYMbFUEGCQ0pIU9/GFS7RZ
9VjqRDeE86U3yHO+WCFoHtd8aQKBgQDTIhi0uq0oY87bUGnWbrrkR0UVRNPDG1BT
cuXACYlv/DV/XpxPC8iPK1UwG4XaOVxodtIRjdBqvb8fUM6HSY6qll64N/4/1jre
Ho+d4clE4tK6a9WU96CKxwHn2BrWUZJPtoldaCZJFJ7SfiHuLlqW7TtYFrOfPIjN
ADiqK+bHIwKBgQCpfIiAVwebo0Z/bWR77+iZFxMwvT4tjdJLVGaXUvXgpjjLmtkm
LTm2S8SZbiSodfz3H+M3dp/pj8wsXiiwyMlZifOITZT/+DPLOUmMK3cVM6ZH8QMy
fkJd/+UhYHhECSlTI10zKByXdi4LZNnIkhwfoLzBMRI9lfeV0dYu2qlfKQKBgEVI
kRbtk1kHt5/ceX62g3nZsV/TYDJMSkW4FJC6EHHBL8UGRQDjewMQUzogLgJ4hEx7
gV/lS5lbftZF7CAVEU4FXjvRlAtav6KYIMTMjQGf9UrbjBEAWZxwxb1Q+y2NQxgJ
bHZMcRPWQnAMmBHTAEM6whicCoGcmb+77Nxa37ZFAoGBALBuUNeD3fKvQR8v6GoA
spv+RYL9TB4wz2Oe9EYSp9z5EiWlTmuvFz3zk8pHDSpntxYH5O5HJ/3OzwhHz9ym
+DNE9AP9LW9hAzMuu7Gob1h8ShGwJVYwrQN3q/83ooUL7WSAuVOLpzJ7BFFlcCjp
MhFvd9iOt/R0N30/3AbQXkOp
-----END PRIVATE KEY-----

View file

@ -0,0 +1,54 @@
use std::io;
use ntex::service::{fn_service, pipeline_factory};
use ntex::{codec, io::filter_factory, io::into_io, io::Io, server, util::Either};
use ntex_openssl::SslAcceptor;
use openssl::ssl::{self, SslFiletype, SslMethod};
#[ntex::main]
async fn main() -> io::Result<()> {
std::env::set_var("RUST_LOG", "trace");
env_logger::init();
println!("Started openssl echp server: 127.0.0.1:8443");
// load ssl keys
let mut builder = ssl::SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
builder
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
.unwrap();
builder
.set_certificate_chain_file("../tests/cert.pem")
.unwrap();
let acceptor = builder.build();
// start server
server::ServerBuilder::new()
.bind("basic", "127.0.0.1:8443", move || {
pipeline_factory(into_io())
.and_then(filter_factory(SslAcceptor::new(acceptor.clone())))
.and_then(fn_service(|io: Io<_>| async move {
println!("New client is connected");
loop {
match io.next(&codec::BytesCodec).await {
Ok(Some(msg)) => {
println!("Got message: {:?}", msg);
io.send(msg.freeze(), &codec::BytesCodec)
.await
.map_err(Either::into_inner)?;
}
Ok(None) => break,
Err(e) => {
println!("Got error: {:?}", e);
break;
}
}
}
println!("Client is disconnected");
Ok(())
}))
})?
.workers(1)
.run()
.await
}

320
ntex-openssl/src/lib.rs Normal file
View file

@ -0,0 +1,320 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::RefCell;
use std::{cmp, error::Error, future::Future, io, pin::Pin, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut};
use ntex_io::{
Filter, FilterFactory, Io, IoRef, ReadFilter, WriteFilter, WriteReadiness,
};
use ntex_util::{future::poll_fn, time, time::Millis};
use openssl::ssl::{self, SslStream};
pub struct SslFilter<F> {
inner: RefCell<SslStream<IoInner<F>>>,
}
struct IoInner<F> {
inner: F,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
impl<F: Filter> io::Read for IoInner<F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(ref mut buf) = self.read_buf {
if buf.is_empty() {
buf.clear();
Err(io::Error::from(io::ErrorKind::WouldBlock))
} else {
let len = cmp::min(buf.len(), dst.len());
dst.copy_from_slice(&buf.split_to(len));
Ok(len)
}
} else {
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
}
}
impl<F: Filter> io::Write for IoInner<F> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.inner.get_write_buf() {
buf.reserve(buf.len());
buf
} else {
BytesMut::with_capacity(src.len())
};
buf.extend_from_slice(src);
self.inner.release_write_buf(buf)?;
Ok(src.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<F: Filter> Filter for SslFilter<F> {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>> {
let ssl_result = self.inner.borrow_mut().shutdown();
match ssl_result {
Ok(ssl::ShutdownResult::Sent) => Poll::Pending,
Ok(ssl::ShutdownResult::Received) => {
self.inner.borrow().get_ref().inner.shutdown(st)
}
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Poll::Ready(Ok(())),
Err(ref e)
if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE =>
{
Poll::Pending
}
Err(e) => Poll::Ready(Err(e
.into_io_error()
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))),
}
}
}
impl<F: Filter> ReadFilter for SslFilter<F> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.inner.borrow().get_ref().inner.poll_read_ready(cx)
}
fn read_closed(&self, err: Option<io::Error>) {
self.inner.borrow().get_ref().inner.read_closed(err)
}
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_read_buf(
&self,
src: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
// store to read_buf
self.inner.borrow_mut().get_mut().read_buf = Some(src);
if new_bytes == 0 {
return Ok(());
}
let mut buf =
if let Some(buf) = self.inner.borrow().get_ref().inner.get_read_buf() {
buf
} else {
BytesMut::with_capacity(4096)
};
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
let result = match ssl_result {
Ok(v) => {
unsafe { buf.advance_mut(v) };
self.inner
.borrow()
.get_ref()
.inner
.release_read_buf(buf, v)?;
Ok(())
}
Err(e) => match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()),
_ => (Err(map_to_ioerr(e))),
},
};
result
}
}
impl<F: Filter> WriteFilter for SslFilter<F> {
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.borrow().get_ref().inner.poll_write_ready(cx)
}
fn write_closed(&self, err: Option<io::Error>) {
self.inner.borrow().get_ref().inner.read_closed(err)
}
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_write_buf(&self, mut buf: BytesMut) -> Result<(), io::Error> {
let ssl_result = self.inner.borrow_mut().ssl_write(&buf);
let result = match ssl_result {
Ok(v) => {
if v != buf.len() {
buf.split_to(v);
self.inner.borrow_mut().get_mut().write_buf = Some(buf);
}
Ok(())
}
Err(e) => match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()),
_ => (Err(map_to_ioerr(e))),
},
};
result
}
}
pub struct SslAcceptor {
acceptor: ssl::SslAcceptor,
timeout: Millis,
}
impl SslAcceptor {
/// Create openssl acceptor filter factory
pub fn new(acceptor: ssl::SslAcceptor) -> Self {
SslAcceptor {
acceptor,
timeout: Millis(5_000),
}
}
/// Set handshake timeout.
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
self.timeout = timeout.into();
self
}
}
impl Clone for SslAcceptor {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
timeout: self.timeout,
}
}
}
impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
type Filter = SslFilter<F>;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, Self::Error>>>>;
fn create(self, st: Io<F>) -> Self::Future {
let timeout = self.timeout;
let ctx_result = ssl::Ssl::new(self.acceptor.context());
Box::pin(async move {
time::timeout(timeout, async {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let st = st.map_filter::<Self, _>(|inner: F| {
let inner = IoInner {
inner,
read_buf: None,
write_buf: None,
};
let ssl_stream =
ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?;
Ok(SslFilter {
inner: RefCell::new(ssl_stream),
})
})?;
poll_fn(|cx| {
let _ = st.write().poll_flush(cx)?;
handle_result(st.filter().inner.borrow_mut().accept(), &st, cx)
.map_err(map_to_ioerr)
})
.await?;
Ok(st)
})
.await
.map_err(|_| {
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout")
})
.and_then(|item| item)
})
}
}
pub struct SslConnector {
ssl: ssl::Ssl,
}
impl SslConnector {
/// Create openssl connector filter factory
pub fn new(ssl: ssl::Ssl) -> Self {
SslConnector { ssl }
}
}
impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
type Filter = SslFilter<F>;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, Self::Error>>>>;
fn create(self, st: Io<F>) -> Self::Future {
Box::pin(async move {
let ssl = self.ssl;
let st = st.map_filter::<Self, _>(|inner: F| {
let inner = IoInner {
inner,
read_buf: None,
write_buf: None,
};
let ssl_stream =
ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?;
Ok(SslFilter {
inner: RefCell::new(ssl_stream),
})
})?;
poll_fn(|cx| {
let _ = st.write().poll_flush(cx)?;
handle_result(st.filter().inner.borrow_mut().connect(), &st, cx)
.map_err(map_to_ioerr)
})
.await?;
Ok(st)
})
}
}
fn handle_result<T: std::fmt::Debug>(
result: Result<T, ssl::Error>,
st: &IoRef,
cx: &mut Context<'_>,
) -> Poll<Result<T, ssl::Error>> {
match result {
Ok(v) => Poll::Ready(Ok(v)),
Err(e) => match e.code() {
ssl::ErrorCode::WANT_READ => {
let _ = st.read().poll_ready(cx);
Poll::Pending
}
ssl::ErrorCode::WANT_WRITE => Poll::Pending,
_ => Poll::Ready(Err(e)),
},
}
}
fn map_to_ioerr<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> io::Error {
io::Error::new(io::ErrorKind::Other, err)
}

View file

@ -24,7 +24,7 @@ path = "src/lib.rs"
default = ["http-framework"]
# openssl
openssl = ["open-ssl", "tokio-openssl"]
openssl = ["open-ssl", "tokio-openssl", "ntex-openssl"]
# rustls support
rustls = ["rust-tls", "rustls-pemfile", "tokio-rustls", "webpki", "webpki-roots"]
@ -51,6 +51,7 @@ ntex-macros = "0.1.3"
ntex-util = "0.1.2"
ntex-bytes = "0.1.7"
ntex-io = { version = "0.1", features = ["tokio"] }
ntex-openssl = { version = "0.1", optional = true }
base64 = "0.13"
bitflags = "1.3"

View file

@ -1,13 +1,15 @@
use std::{future::Future, io, pin::Pin, task::Context, task::Poll};
use ntex_openssl::{SslConnector as IoSslConnector, SslFilter};
pub use open_ssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
pub use tokio_openssl::SslStream;
use crate::io::{DefaultFilter, Io};
use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory};
use crate::util::Ready;
use super::{Address, Connect, ConnectError, Connector};
use super::{Address, Connect, ConnectError, Connector, IoConnector as BaseIoConnector};
pub struct OpensslConnector<T> {
connector: Connector<T>,
@ -106,6 +108,101 @@ impl<T: Address + 'static> Service for OpensslConnector<T> {
}
}
pub struct IoConnector<T> {
connector: BaseIoConnector<T>,
openssl: SslConnector,
}
impl<T> IoConnector<T> {
/// Construct new OpensslConnectService factory
pub fn new(connector: SslConnector) -> Self {
IoConnector {
connector: BaseIoConnector::default(),
openssl: connector,
}
}
}
impl<T: Address + 'static> IoConnector<T> {
/// Resolve and connect to remote host
pub fn connect<U>(
&self,
message: U,
) -> impl Future<Output = Result<Io<SslFilter<DefaultFilter>>, ConnectError>>
where
Connect<T>: From<U>,
{
let message = Connect::from(message);
let host = message.host().to_string();
let conn = self.connector.call(message);
let openssl = self.openssl.clone();
async move {
let io = conn.await?;
trace!("SSL Handshake start for: {:?}", host);
match openssl.configure() {
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
Ok(config) => {
let ssl = config
.into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
match io.add_filter(IoSslConnector::new(ssl)).await {
Ok(io) => {
trace!("SSL Handshake success: {:?}", host);
Ok(io)
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))
.into())
}
}
}
}
}
}
}
impl<T> Clone for IoConnector<T> {
fn clone(&self) -> Self {
IoConnector {
connector: self.connector.clone(),
openssl: self.openssl.clone(),
}
}
}
impl<T: Address + 'static> ServiceFactory for IoConnector<T> {
type Request = Connect<T>;
type Response = Io<SslFilter<DefaultFilter>>;
type Error = ConnectError;
type Config = ();
type Service = IoConnector<T>;
type InitError = ();
type Future = Ready<Self::Service, Self::InitError>;
fn new_service(&self, _: ()) -> Self::Future {
Ready::Ok(self.clone())
}
}
impl<T: Address + 'static> Service for IoConnector<T> {
type Request = Connect<T>;
type Response = Io<SslFilter<DefaultFilter>>;
type Error = ConnectError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
#[inline]
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&self, req: Connect<T>) -> Self::Future {
Box::pin(self.connect(req))
}
}
#[cfg(test)]
mod tests {
use super::*;