impl rustls acceptor; update filter trait

This commit is contained in:
Nikolay Kim 2021-12-19 01:43:43 +06:00
parent d7083c15d8
commit 1af728eb01
30 changed files with 1184 additions and 306 deletions

9
ntex-io/CHANGES.md Normal file
View file

@ -0,0 +1,9 @@
# Changes
## [0.1.0-b.1] - 2021-12-18
* Modify filter's release_read/write_buf return type
## [0.1.0-b.0] - 2021-12-18
* Refactor ntex::framed to ntex-io

View file

@ -339,7 +339,6 @@ where
if slf.shared.inflight.get() == 0 {
slf.st.set(DispatcherState::Shutdown);
state.init_shutdown(cx);
} else {
state.register_dispatcher(cx);
return Poll::Pending;

View file

@ -69,7 +69,7 @@ impl ReadFilter for DefaultFilter {
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
) -> Result<bool, io::Error> {
let mut flags = self.0.flags.get();
if new_bytes > 0 {
@ -86,7 +86,7 @@ impl ReadFilter for DefaultFilter {
}
self.0.read_buf.set(Some(buf));
Ok(())
Ok(false)
}
}
@ -133,7 +133,7 @@ impl WriteFilter for DefaultFilter {
}
#[inline]
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
fn release_write_buf(&self, buf: BytesMut) -> Result<bool, io::Error> {
let pool = self.0.pool.get();
if buf.is_empty() {
pool.release_write_buf(buf);
@ -141,7 +141,7 @@ impl WriteFilter for DefaultFilter {
self.0.write_buf.set(Some(buf));
self.0.write_task.wake();
}
Ok(())
Ok(false)
}
}
@ -176,8 +176,8 @@ impl ReadFilter for NullFilter {
None
}
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> {
Ok(())
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<bool, io::Error> {
Ok(true)
}
}
@ -192,7 +192,7 @@ impl WriteFilter for NullFilter {
None
}
fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> {
Ok(())
fn release_write_buf(&self, _: BytesMut) -> Result<bool, io::Error> {
Ok(true)
}
}

View file

@ -43,8 +43,7 @@ pub trait ReadFilter {
fn get_read_buf(&self) -> Option<BytesMut>;
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize)
-> Result<(), io::Error>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<bool, io::Error>;
}
pub trait WriteFilter {
@ -55,7 +54,7 @@ pub trait WriteFilter {
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
fn release_write_buf(&self, buf: BytesMut) -> Result<bool, io::Error>;
}
pub trait Filter: ReadFilter + WriteFilter + 'static {

View file

@ -124,6 +124,28 @@ impl IoStateInner {
self.notify_disconnect();
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, cx: Option<&mut Context<'_>>, st: &IoRef) {
let mut flags = self.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
flags.insert(Flags::IO_FILTERS);
if let Err(err) = self.shutdown_filters(st) {
self.error.set(Some(err));
flags.insert(Flags::IO_ERR);
}
self.flags.set(flags);
self.read_task.wake();
self.write_task.wake();
if let Some(cx) = cx {
self.dispatch_task.register(cx.waker());
}
}
}
#[inline]
pub(super) fn shutdown_filters(&self, st: &IoRef) -> Result<(), io::Error> {
let mut flags = self.flags.get();
@ -328,13 +350,13 @@ impl IoRef {
#[inline]
/// Get api for read task
pub fn read(&'_ self) -> ReadRef<'_> {
ReadRef(self.0.as_ref())
ReadRef(self)
}
#[inline]
/// Get api for write task
pub fn write(&'_ self) -> WriteRef<'_> {
WriteRef(self.0.as_ref())
WriteRef(self)
}
#[inline]
@ -425,7 +447,7 @@ impl IoRef {
Poll::Ready(Ok(()))
} else {
if !flags.contains(Flags::IO_FILTERS) {
self.init_shutdown(cx);
self.0.init_shutdown(Some(cx), self);
}
if let Some(err) = self.0.error.take() {
@ -465,25 +487,6 @@ impl IoRef {
Err(err) => Poll::Ready(Err(Either::Left(err))),
}
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, cx: &mut Context<'_>) {
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
self.0.insert_flags(Flags::IO_FILTERS);
if let Err(err) = self.0.shutdown_filters(self) {
self.0.error.set(Some(err));
self.0.insert_flags(Flags::IO_ERR);
}
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.register(cx.waker());
}
}
}
impl fmt::Debug for IoRef {
@ -539,10 +542,10 @@ impl<F: Filter> Io<F> {
}
#[inline]
pub fn map_filter<T, U>(mut self, map: U) -> Result<Io<T::Filter>, T::Error>
pub fn map_filter<T, U, E>(mut self, map: U) -> Result<Io<T>, E>
where
T: FilterFactory<F>,
U: FnOnce(F) -> Result<T::Filter, T::Error>,
T: Filter,
U: FnOnce(F) -> Result<T, E>,
{
// replace current filter
let filter = unsafe {
@ -610,22 +613,22 @@ impl<F> Deref for Io<F> {
}
#[derive(Copy, Clone)]
pub struct WriteRef<'a>(pub(super) &'a IoStateInner);
pub struct WriteRef<'a>(pub(super) &'a IoRef);
impl<'a> WriteRef<'a> {
#[inline]
/// Check if write task is ready
pub fn is_ready(&self) -> bool {
!self.0.flags.get().contains(Flags::WR_BACKPRESSURE)
!self.0 .0.flags.get().contains(Flags::WR_BACKPRESSURE)
}
#[inline]
/// Check if write buffer is full
pub fn is_full(&self) -> bool {
if let Some(buf) = self.0.read_buf.take() {
let hw = self.0.pool.get().write_params_high();
if let Some(buf) = self.0 .0.read_buf.take() {
let hw = self.0 .0.pool.get().write_params_high();
let result = buf.len() >= hw;
self.0.write_buf.set(Some(buf));
self.0 .0.write_buf.set(Some(buf));
result
} else {
false
@ -635,7 +638,7 @@ impl<'a> WriteRef<'a> {
#[inline]
/// Wake dispatcher task
pub fn wake_dispatcher(&self) {
self.0.dispatch_task.wake();
self.0 .0.dispatch_task.wake();
}
#[inline]
@ -644,9 +647,9 @@ impl<'a> WriteRef<'a> {
/// Write task must be waken up separately.
pub fn enable_backpressure(&self, cx: Option<&mut Context<'_>>) {
log::trace!("enable write back-pressure {:?}", cx.is_some());
self.0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
if let Some(cx) = cx {
self.0.dispatch_task.register(cx.waker());
self.0 .0.dispatch_task.register(cx.waker());
}
}
@ -656,16 +659,19 @@ impl<'a> WriteRef<'a> {
where
F: FnOnce(&mut BytesMut) -> R,
{
let filter = self.0.filter.get();
let filter = self.0 .0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
if buf.is_empty() {
self.0.write_task.wake();
self.0 .0.write_task.wake();
}
let result = f(&mut buf);
filter.release_write_buf(buf)?;
let close = filter.release_write_buf(buf)?;
if close {
self.0 .0.init_shutdown(None, self.0);
}
Ok(result)
}
@ -681,15 +687,15 @@ impl<'a> WriteRef<'a> {
where
U: Encoder,
{
let flags = self.0.flags.get();
let flags = self.0 .0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0.filter.get();
let filter = self.0 .0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
let is_write_sleep = buf.is_empty();
let (hw, lw) = self.0.pool.get().write_params().unpack();
let (hw, lw) = self.0 .0.pool.get().write_params().unpack();
// make sure we've got room
let remaining = buf.capacity() - buf.len();
@ -700,12 +706,19 @@ impl<'a> WriteRef<'a> {
// encode item and wake write task
let result = codec.encode(item, &mut buf).map(|_| {
if is_write_sleep {
self.0.write_task.wake();
self.0 .0.write_task.wake();
}
buf.len() < hw
});
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
match filter.release_write_buf(buf) {
Err(err) => {
self.0 .0.set_error(Some(err));
}
Ok(close) => {
if close {
self.0 .0.init_shutdown(None, self.0);
}
}
}
result
} else {
@ -718,24 +731,24 @@ impl<'a> WriteRef<'a> {
///
/// Returns write buffer state, false is returned if write buffer if full.
pub fn write(&self, src: &[u8]) -> Result<bool, io::Error> {
let flags = self.0.flags.get();
let flags = self.0 .0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0.filter.get();
let filter = self.0 .0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
let is_write_sleep = buf.is_empty();
// write and wake write task
buf.extend_from_slice(src);
let result = buf.len() < self.0.pool.get().write_params_high();
let result = buf.len() < self.0 .0.pool.get().write_params_high();
if is_write_sleep {
self.0.write_task.wake();
self.0 .0.write_task.wake();
}
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
self.0 .0.set_error(Some(err));
}
Ok(result)
} else {
@ -755,27 +768,27 @@ impl<'a> WriteRef<'a> {
full: bool,
) -> Poll<Result<(), io::Error>> {
// check io error
if !self.0.is_io_open() {
return Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
if !self.0 .0.is_io_open() {
return Poll::Ready(Err(self.0 .0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "disconnected")
})));
}
if let Some(buf) = self.0.write_buf.take() {
if let Some(buf) = self.0 .0.write_buf.take() {
let len = buf.len();
if len != 0 {
self.0.write_buf.set(Some(buf));
self.0 .0.write_buf.set(Some(buf));
if full {
self.0.insert_flags(Flags::WR_WAIT);
self.0.dispatch_task.register(cx.waker());
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= self.0.pool.get().write_params_high() << 1 {
self.0.insert_flags(Flags::WR_BACKPRESSURE);
self.0.dispatch_task.register(cx.waker());
} else if len >= self.0 .0.pool.get().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else {
self.0.remove_flags(Flags::WR_BACKPRESSURE);
self.0 .0.remove_flags(Flags::WR_BACKPRESSURE);
}
}
}
@ -793,21 +806,30 @@ impl<'a> WriteRef<'a> {
}
#[derive(Copy, Clone)]
pub struct ReadRef<'a>(&'a IoStateInner);
pub struct ReadRef<'a>(&'a IoRef);
impl<'a> ReadRef<'a> {
#[inline]
/// Check if read buffer has new data
pub fn is_ready(&self) -> bool {
self.0.flags.get().contains(Flags::RD_READY)
self.0 .0.flags.get().contains(Flags::RD_READY)
}
/// Reset readiness state, returns previous state
pub fn take_readiness(&self) -> bool {
let mut flags = self.0 .0.flags.get();
let ready = flags.contains(Flags::RD_READY);
flags.remove(Flags::RD_READY);
self.0 .0.flags.set(flags);
ready
}
#[inline]
/// Check if read buffer is full
pub fn is_full(&self) -> bool {
if let Some(buf) = self.0.read_buf.take() {
let result = buf.len() >= self.0.pool.get().read_params_high();
self.0.read_buf.set(Some(buf));
if let Some(buf) = self.0 .0.read_buf.take() {
let result = buf.len() >= self.0 .0.pool.get().read_params_high();
self.0 .0.read_buf.set(Some(buf));
result
} else {
false
@ -817,17 +839,17 @@ impl<'a> ReadRef<'a> {
#[inline]
/// Pause read task
pub fn pause(&self, cx: &mut Context<'_>) {
self.0.insert_flags(Flags::RD_PAUSED);
self.0.dispatch_task.register(cx.waker());
self.0 .0.insert_flags(Flags::RD_PAUSED);
self.0 .0.dispatch_task.register(cx.waker());
}
#[inline]
/// Wake read io task if it is paused
pub fn resume(&self) -> bool {
let flags = self.0.flags.get();
let flags = self.0 .0.flags.get();
if flags.contains(Flags::RD_PAUSED) {
self.0.remove_flags(Flags::RD_PAUSED);
self.0.read_task.wake();
self.0 .0.remove_flags(Flags::RD_PAUSED);
self.0 .0.read_task.wake();
true
} else {
false
@ -846,9 +868,9 @@ impl<'a> ReadRef<'a> {
where
U: Decoder,
{
if let Some(mut buf) = self.0.read_buf.take() {
if let Some(mut buf) = self.0 .0.read_buf.take() {
let result = codec.decode(&mut buf);
self.0.read_buf.set(Some(buf));
self.0 .0.read_buf.set(Some(buf));
return result;
}
Ok(None)
@ -862,14 +884,15 @@ impl<'a> ReadRef<'a> {
{
let mut buf = self
.0
.0
.read_buf
.take()
.unwrap_or_else(|| self.0.pool.get().get_read_buf());
.unwrap_or_else(|| self.0 .0.pool.get().get_read_buf());
let res = f(&mut buf);
if buf.is_empty() {
self.0.pool.get().release_read_buf(buf);
self.0 .0.pool.get().release_read_buf(buf);
} else {
self.0.read_buf.set(Some(buf));
self.0 .0.read_buf.set(Some(buf));
}
res
}
@ -883,23 +906,23 @@ impl<'a> ReadRef<'a> {
&self,
cx: &mut Context<'_>,
) -> Result<(), Option<io::Error>> {
let mut flags = self.0.flags.get();
let mut flags = self.0 .0.flags.get();
if !self.0.is_io_open() {
Err(self.0.error.take())
if !self.0 .0.is_io_open() {
Err(self.0 .0.error.take())
} else {
if flags.contains(Flags::RD_BUF_FULL) {
log::trace!("read back-pressure is disabled, wake io task");
flags.remove(Flags::RD_READY | Flags::RD_BUF_FULL);
self.0.flags.set(flags);
self.0.read_task.wake();
self.0 .0.flags.set(flags);
self.0 .0.read_task.wake();
} else if flags.contains(Flags::RD_READY) {
log::trace!("waking up io read task");
flags.remove(Flags::RD_READY);
self.0.flags.set(flags);
self.0.read_task.wake();
self.0 .0.flags.set(flags);
self.0 .0.read_task.wake();
}
self.0.dispatch_task.register(cx.waker());
self.0 .0.dispatch_task.register(cx.waker());
Ok(())
}
}

View file

@ -49,10 +49,12 @@ impl ReadContext {
self.0 .0.dispatch_task.wake();
}
self.0 .0.filter.get().release_read_buf(buf, new_bytes)?;
let close = self.0 .0.filter.get().release_read_buf(buf, new_bytes)?;
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters(&self.0)?;
} else if close {
self.0 .0.init_shutdown(None, &self.0);
}
Ok(())
}

View file

@ -170,6 +170,7 @@ impl Future for WriteTask {
}
}
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
log::trace!("initiate timeout delay for {:?}", time);
if delay.is_none() {
*delay = Some(sleep(time));
}