cleanup ntex-io api

This commit is contained in:
Nikolay Kim 2021-12-20 18:16:17 +06:00
parent a5d734fe47
commit ed57a964b6
30 changed files with 1670 additions and 1726 deletions

View file

@ -1,5 +1,11 @@
# Changes
## [0.1.0-b.2] - 2021-12-20
* Removed `WriteRef` and `ReadRef`
* Better Io/IoRef api separation
## [0.1.0-b.1] - 2021-12-19
* Removed `ReadFilter`/`WriteFilter` traits.

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.1.0-b.1"
version = "0.1.0-b.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]

View file

@ -7,7 +7,7 @@ use ntex_service::{IntoService, Service};
use ntex_util::future::Either;
use ntex_util::time::{now, Seconds};
use super::{rt::spawn, DispatchItem, IoBoxed, ReadRef, Timer, WriteRef};
use super::{rt::spawn, DispatchItem, IoBoxed, IoRef, Timer};
type Response<U> = <U as Encoder>::Item;
@ -33,8 +33,8 @@ where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Encoder + Decoder,
{
io: IoBoxed,
st: Cell<DispatcherState>,
state: IoBoxed,
timer: Timer,
ka_timeout: Seconds,
ka_updated: Cell<time::Instant>,
@ -90,7 +90,7 @@ where
{
/// Construct new `Dispatcher` instance.
pub fn new<F: IntoService<S>>(
state: IoBoxed,
io: IoBoxed,
codec: U,
service: F,
timer: Timer,
@ -100,13 +100,13 @@ where
// register keepalive timer
let expire = updated + time::Duration::from(ka_timeout);
timer.register(expire, expire, &state);
timer.register(expire, expire, &io);
Dispatcher {
service: service.into_service(),
fut: None,
inner: DispatcherInner {
pool: state.memory_pool().pool(),
pool: io.memory_pool().pool(),
ka_updated: Cell::new(updated),
error: Cell::new(None),
ready_err: Cell::new(false),
@ -116,7 +116,7 @@ where
error: Cell::new(None),
inflight: Cell::new(0),
}),
state,
io,
timer,
ka_timeout,
},
@ -132,10 +132,10 @@ where
// register keepalive timer
let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka());
if timeout.is_zero() {
self.inner.timer.unregister(prev, &self.inner.state);
self.inner.timer.unregister(prev, &self.inner.io);
} else {
let expire = self.inner.ka_updated.get() + time::Duration::from(timeout);
self.inner.timer.register(expire, prev, &self.inner.state);
self.inner.timer.register(expire, prev, &self.inner.io);
}
self.inner.ka_timeout = timeout;
@ -151,7 +151,7 @@ where
///
/// By default disconnect timeout is set to 1 seconds.
pub fn disconnect_timeout(self, val: Seconds) -> Self {
self.inner.state.set_disconnect_timeout(val.into());
self.inner.io.set_disconnect_timeout(val.into());
self
}
}
@ -161,18 +161,18 @@ where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Encoder + Decoder + 'static,
{
fn handle_result(&self, item: Result<S::Response, S::Error>, write: WriteRef<'_>) {
fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoRef) {
self.inflight.set(self.inflight.get() - 1);
match item {
Ok(Some(val)) => match write.encode(val, &self.codec) {
Ok(Some(val)) => match io.encode(val, &self.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Ok(false) => io.enable_write_backpressure(),
Err(err) => self.error.set(Some(DispatcherError::Encoder(err))),
},
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
Ok(None) => return,
}
write.wake_dispatcher();
io.wake_dispatcher();
}
}
@ -186,9 +186,8 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
let slf = &this.inner;
let state = &slf.state;
let read = state.read();
let write = state.write();
let io = &slf.io;
let ioref = io.as_ref();
// handle service response future
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
@ -197,79 +196,58 @@ where
Poll::Ready(item) => {
this.fut.set(None);
slf.shared.inflight.set(slf.shared.inflight.get() - 1);
slf.handle_result(item, write);
slf.handle_result(item, ioref);
}
}
}
// handle memory pool pressure
if slf.pool.poll_ready(cx).is_pending() {
read.pause(cx);
io.pause(cx);
return Poll::Pending;
}
loop {
match slf.st.get() {
DispatcherState::Processing => {
let result = match slf.poll_service(this.service, cx, read) {
Poll::Pending => {
if let Err(err) = read.poll_read_ready(cx) {
log::error!(
"io error while service is in pending state: {:?}",
err
);
return Poll::Ready(Ok(()));
}
return Poll::Pending;
}
Poll::Ready(result) => result,
let result = if let Poll::Ready(result) =
slf.poll_service(this.service, cx, io)
{
result
} else {
return Poll::Pending;
};
let item = match result {
PollService::Ready => {
if !write.is_ready() {
if !io.is_write_ready() {
// instruct write task to notify dispatcher when data is flushed
write.enable_backpressure(Some(cx));
io.enable_write_backpressure(cx);
slf.st.set(DispatcherState::Backpressure);
DispatchItem::WBackPressureEnabled
} else if read.is_ready() {
} else {
// decode incoming bytes if buffer is ready
match read.decode(&slf.shared.codec) {
Ok(Some(el)) => {
match io.poll_read_next(&slf.shared.codec, cx) {
Poll::Ready(Some(Ok(el))) => {
slf.update_keepalive();
DispatchItem::Item(el)
}
Ok(None) => {
log::trace!("not enough data to decode next frame, register dispatch task");
// service is ready, wake io read task
match read.poll_read_ready(cx) {
Ok(()) => {
read.resume();
return Poll::Pending;
}
Err(None) => DispatchItem::Disconnect(None),
Err(Some(err)) => {
DispatchItem::Disconnect(Some(err))
}
}
}
Err(err) => {
Poll::Ready(Some(Err(Either::Left(err)))) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::DecoderError(err)
}
}
} else {
// no new events
match read.poll_read_ready(cx) {
Ok(()) => {
read.resume();
return Poll::Pending;
}
Err(None) => DispatchItem::Disconnect(None),
Err(Some(err)) => {
Poll::Ready(Some(Err(Either::Right(err)))) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::Disconnect(Some(err))
}
Poll::Ready(None) => DispatchItem::Disconnect(None),
Poll::Pending => {
log::trace!("not enough data to decode next frame, register dispatch task");
io.resume();
return Poll::Pending;
}
}
}
}
@ -284,7 +262,7 @@ where
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
Poll::Ready(res) => {
this.fut.set(None);
slf.handle_result(res, write);
slf.handle_result(res, ioref);
}
Poll::Pending => {
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
@ -296,13 +274,13 @@ where
}
// handle write back-pressure
DispatcherState::Backpressure => {
let result = match slf.poll_service(this.service, cx, read) {
let result = match slf.poll_service(this.service, cx, io) {
Poll::Ready(result) => result,
Poll::Pending => return Poll::Pending,
};
let item = match result {
PollService::Ready => {
if write.is_ready() {
if io.is_write_ready() {
slf.st.set(DispatcherState::Processing);
DispatchItem::WBackPressureDisabled
} else {
@ -320,7 +298,7 @@ where
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
Poll::Ready(res) => {
this.fut.set(None);
slf.handle_result(res, write);
slf.handle_result(res, ioref);
}
Poll::Pending => {
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
@ -338,11 +316,14 @@ where
}
if slf.shared.inflight.get() == 0 {
slf.st.set(DispatcherState::Shutdown);
if io.poll_shutdown(cx).is_ready() {
slf.st.set(DispatcherState::Shutdown);
continue;
}
} else {
state.register_dispatcher(cx);
return Poll::Pending;
slf.io.register_dispatcher(cx);
}
return Poll::Pending;
}
// shutdown service
DispatcherState::Shutdown => {
@ -375,23 +356,23 @@ where
fn spawn_service_call(&self, fut: S::Future) {
self.shared.inflight.set(self.shared.inflight.get() + 1);
let st = self.state.get_ref();
let st = self.io.get_ref();
let shared = self.shared.clone();
spawn(async move {
let item = fut.await;
shared.handle_result(item, st.write());
shared.handle_result(item, &st);
});
}
fn handle_result(
&self,
item: Result<Option<<U as Encoder>::Item>, S::Error>,
write: WriteRef<'_>,
io: &IoRef,
) {
match item {
Ok(Some(item)) => match write.encode(item, &self.shared.codec) {
Ok(Some(item)) => match io.encode(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Ok(false) => io.enable_write_backpressure(),
Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))),
},
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
@ -403,7 +384,7 @@ where
&self,
srv: &S,
cx: &mut Context<'_>,
read: ReadRef<'_>,
io: &IoBoxed,
) -> Poll<PollService<U>> {
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
@ -428,19 +409,19 @@ where
PollService::ServiceError
}
}
} else if self.state.is_dispatcher_stopped() {
} else if self.io.is_dispatcher_stopped() {
log::trace!("dispatcher is instructed to stop");
self.unregister_keepalive();
// process unhandled data
if let Ok(Some(el)) = read.decode(&self.shared.codec) {
if let Ok(Some(el)) = io.decode(&self.shared.codec) {
PollService::Item(DispatchItem::Item(el))
} else {
self.st.set(DispatcherState::Stop);
// get io error
if let Some(err) = self.state.take_error() {
if let Some(err) = self.io.take_error() {
PollService::Item(DispatchItem::Disconnect(Some(err)))
} else {
PollService::ServiceError
@ -453,7 +434,7 @@ where
// pause io read task
Poll::Pending => {
log::trace!("service is not ready, register dispatch task");
read.pause(cx);
io.pause(cx);
Poll::Pending
}
// handle service readiness error
@ -478,7 +459,7 @@ where
/// check keepalive timeout
fn check_keepalive(&self) {
if self.state.is_keepalive() {
if self.io.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
@ -494,11 +475,8 @@ where
let updated = now();
if updated != self.ka_updated.get() {
let ka = time::Duration::from(self.ka());
self.timer.register(
updated + ka,
self.ka_updated.get() + ka,
&self.state,
);
self.timer
.register(updated + ka, self.ka_updated.get() + ka, &self.io);
self.ka_updated.set(updated);
}
}
@ -509,7 +487,7 @@ where
if self.ka_enabled() {
self.timer.unregister(
self.ka_updated.get() + time::Duration::from(self.ka()),
&self.state,
&self.io,
);
}
}
@ -527,7 +505,7 @@ mod tests {
use ntex_util::time::{sleep, Millis};
use crate::testing::IoTest;
use crate::{state::Flags, Io, IoRef, IoStream, WriteRef};
use crate::{io::Flags, Io, IoRef, IoStream};
use super::*;
@ -538,8 +516,8 @@ mod tests {
self.0.flags()
}
fn write(&'_ self) -> WriteRef<'_> {
WriteRef(&self.0)
fn io(&self) -> &IoRef {
&self.0
}
fn close(&self) {
@ -587,7 +565,7 @@ mod tests {
ready_err: Cell::new(false),
st: Cell::new(DispatcherState::Processing),
pool: state.memory_pool().pool(),
state: state.into_boxed(),
io: state.into_boxed(),
shared,
timer,
ka_timeout,
@ -634,7 +612,6 @@ mod tests {
#[ntex::test]
async fn test_sink() {
env_logger::init();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
@ -658,7 +635,7 @@ mod tests {
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(st
.write()
.io()
.encode(Bytes::from_static(b"test"), &mut BytesCodec)
.is_ok());
let buf = client.read().await.unwrap();
@ -684,7 +661,7 @@ mod tests {
}),
);
state
.write()
.io()
.encode(
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
&mut BytesCodec,
@ -737,7 +714,7 @@ mod tests {
let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
state
.write()
.io()
.encode(
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
&mut BytesCodec,
@ -821,19 +798,19 @@ mod tests {
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
// response message
assert!(!state.write().is_ready());
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 65536);
assert!(!state.io().is_write_ready());
assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);
client.remote_buffer_cap(10240);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 55296);
assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);
client.remote_buffer_cap(45056);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 10240);
assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 10240);
// backpressure disabled
assert!(state.write().is_ready());
assert!(state.io().is_write_ready());
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
}

View file

@ -2,8 +2,8 @@ use std::{any, io, task::Context, task::Poll};
use ntex_bytes::BytesMut;
use super::state::{Flags, IoRef};
use super::{Filter, WriteReadiness};
use super::io::Flags;
use super::{Filter, IoRef, WriteReadiness};
pub struct DefaultFilter(IoRef);

698
ntex-io/src/io.rs Normal file
View file

@ -0,0 +1,698 @@
use std::cell::{Cell, RefCell};
use std::task::{Context, Poll};
use std::{fmt, future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
use ntex_bytes::{BytesMut, PoolId, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Millis};
use super::filter::{DefaultFilter, NullFilter};
use super::tasks::{ReadContext, WriteContext};
use super::{Filter, FilterFactory, Handle, IoStream};
bitflags::bitflags! {
    /// Runtime state bits shared by the io object, its read/write tasks
    /// and the dispatcher.
    pub struct Flags: u16 {
        /// io error occurred
        const IO_ERR = 0b0000_0000_0000_0001;
        /// shutting down filters
        const IO_FILTERS = 0b0000_0000_0000_0010;
        /// shutting down filters timeout
        const IO_FILTERS_TO = 0b0000_0000_0000_0100;
        /// shutdown io tasks
        const IO_SHUTDOWN = 0b0000_0000_0000_1000;
        /// io object is closed
        const IO_CLOSED = 0b0000_0000_0001_0000;
        /// pause io read
        const RD_PAUSED = 0b0000_0000_0010_0000;
        /// new data is available
        const RD_READY = 0b0000_0000_0100_0000;
        /// read buffer is full
        const RD_BUF_FULL = 0b0000_0000_1000_0000;
        /// wait write completion
        const WR_WAIT = 0b0000_0001_0000_0000;
        /// write buffer is full
        const WR_BACKPRESSURE = 0b0000_0010_0000_0000;
        /// dispatcher is marked stopped
        const DSP_STOP = 0b0001_0000_0000_0000;
        /// keep-alive timeout occurred
        const DSP_KEEPALIVE = 0b0010_0000_0000_0000;
        /// dispatcher returned error
        const DSP_ERR = 0b0100_0000_0000_0000;
    }
}
/// Storage for the current filter: either type-erased (boxed), or a raw
/// pointer to the concrete filter type `F` (owned by `Io`, freed in its
/// `Drop` impl; a null pointer marks "ownership already moved out").
enum FilterItem<F> {
    Boxed(Box<dyn Filter>),
    Ptr(*mut F),
}

/// Io object with an attached filter of type `F`.
pub struct Io<F = DefaultFilter>(pub(super) IoRef, FilterItem<F>);

/// Cheaply clonable handle to the shared io state.
#[derive(Clone)]
pub struct IoRef(pub(super) Rc<IoState>);

/// State shared between `Io`, `IoRef` and the io read/write tasks.
pub(crate) struct IoState {
    pub(super) flags: Cell<Flags>,
    pub(super) pool: Cell<PoolRef>,
    pub(super) disconnect_timeout: Cell<Millis>,
    pub(super) error: Cell<Option<io::Error>>,
    // wakers for the three cooperating tasks
    pub(super) read_task: LocalWaker,
    pub(super) write_task: LocalWaker,
    pub(super) dispatch_task: LocalWaker,
    pub(super) read_buf: Cell<Option<BytesMut>>,
    pub(super) write_buf: Cell<Option<BytesMut>>,
    // NOTE: the 'static lifetime is a lie; the reference points into the
    // filter box owned by `Io` (see `Io::with_memory_pool`)
    pub(super) filter: Cell<&'static dyn Filter>,
    pub(super) handle: Cell<Option<Box<dyn Handle>>>,
    pub(super) on_disconnect: RefCell<Vec<Option<LocalWaker>>>,
}
impl IoState {
    #[inline]
    /// Insert flag bits into the shared flags cell.
    pub(super) fn insert_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.insert(f);
        self.flags.set(flags);
    }

    #[inline]
    /// Remove flag bits from the shared flags cell.
    pub(super) fn remove_flags(&self, f: Flags) {
        let mut flags = self.flags.get();
        flags.remove(f);
        self.flags.set(flags);
    }

    #[inline]
    /// Mark keep-alive timeout and wake the dispatcher; only fires once
    /// until `DSP_KEEPALIVE` is reset.
    pub(super) fn notify_keepalive(&self) {
        let mut flags = self.flags.get();
        if !flags.contains(Flags::DSP_KEEPALIVE) {
            flags.insert(Flags::DSP_KEEPALIVE);
            self.flags.set(flags);
            self.dispatch_task.wake();
        }
    }

    #[inline]
    /// Wake every registered `OnDisconnect` waiter, consuming its waker.
    pub(super) fn notify_disconnect(&self) {
        let mut on_disconnect = self.on_disconnect.borrow_mut();
        for item in &mut *on_disconnect {
            if let Some(waker) = item.take() {
                waker.wake();
            }
        }
    }

    #[inline]
    /// Check that io is neither failed, shutting down nor closed.
    pub(super) fn is_io_open(&self) -> bool {
        // FIX(review): the original mask listed `IO_SHUTDOWN` twice
        // (`IO_ERR | IO_SHUTDOWN | IO_SHUTDOWN | IO_CLOSED`); the
        // duplicate is removed here. NOTE(review): `IO_FILTERS` may have
        // been the intended fourth bit - confirm against call sites.
        !self
            .flags
            .get()
            .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_CLOSED)
    }

    #[inline]
    /// Record an io error (if any), wake all tasks and notify
    /// disconnect waiters.
    pub(super) fn set_error(&self, err: Option<io::Error>) {
        if err.is_some() {
            self.error.set(err);
        }
        self.read_task.wake();
        self.write_task.wake();
        self.dispatch_task.wake();
        self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
        self.notify_disconnect();
    }

    #[inline]
    /// Gracefully shutdown read and write io tasks
    pub(super) fn init_shutdown(&self, cx: Option<&mut Context<'_>>, st: &IoRef) {
        let flags = self.flags.get();

        // only start shutdown once; skip if already failed/shutting down
        if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
            log::trace!("initiate io shutdown {:?}", flags);
            self.insert_flags(Flags::IO_FILTERS);
            if let Err(err) = self.shutdown_filters(st) {
                self.error.set(Some(err));
            }

            self.read_task.wake();
            self.write_task.wake();
            if let Some(cx) = cx {
                self.dispatch_task.register(cx.waker());
            }
        }
    }

    #[inline]
    /// Drive filter shutdown; sets `IO_SHUTDOWN` on completion or
    /// `IO_ERR` (and wakes the dispatcher) on failure.
    pub(super) fn shutdown_filters(&self, st: &IoRef) -> Result<(), io::Error> {
        let mut flags = self.flags.get();
        if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
            let result = match self.filter.get().shutdown(st) {
                Poll::Pending => return Ok(()),
                Poll::Ready(Ok(())) => {
                    flags.insert(Flags::IO_SHUTDOWN);
                    Ok(())
                }
                Poll::Ready(Err(err)) => {
                    flags.insert(Flags::IO_ERR);
                    self.dispatch_task.wake();
                    Err(err)
                }
            };
            self.flags.set(flags);
            self.read_task.wake();
            self.write_task.wake();
            result
        } else {
            Ok(())
        }
    }
}
impl Eq for IoState {}

impl PartialEq for IoState {
    #[inline]
    /// Identity-based equality: two states are equal only when they are
    /// the same allocation.
    fn eq(&self, other: &Self) -> bool {
        ptr::eq(self, other)
    }
}

impl hash::Hash for IoState {
    #[inline]
    /// Hash by address, consistent with the identity-based `PartialEq`.
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        (self as *const _ as usize).hash(state);
    }
}
impl Drop for IoState {
    #[inline]
    fn drop(&mut self) {
        // return any remaining buffers to the memory pool
        if let Some(buf) = self.read_buf.take() {
            self.pool.get().release_read_buf(buf);
        }
        if let Some(buf) = self.write_buf.take() {
            self.pool.get().release_write_buf(buf);
        }
    }
}
impl Io {
    #[inline]
    /// Create `State` instance
    pub fn new<I: IoStream>(io: I) -> Self {
        Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
    }

    #[inline]
    /// Create `State` instance with specific memory pool.
    pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
        let inner = Rc::new(IoState {
            pool: Cell::new(pool),
            flags: Cell::new(Flags::empty()),
            error: Cell::new(None),
            disconnect_timeout: Cell::new(Millis::ONE_SEC),
            dispatch_task: LocalWaker::new(),
            read_task: LocalWaker::new(),
            write_task: LocalWaker::new(),
            read_buf: Cell::new(None),
            write_buf: Cell::new(None),
            filter: Cell::new(NullFilter::get()),
            handle: Cell::new(None),
            on_disconnect: RefCell::new(Vec::new()),
        });

        let filter = Box::new(DefaultFilter::new(IoRef(inner.clone())));
        // SAFETY: the 'static reference points into `filter`, whose raw
        // pointer is stored in the returned Io and freed only in Drop,
        // after the filter cell has been reset to NullFilter
        let filter_ref: &'static dyn Filter = unsafe {
            let filter: &dyn Filter = filter.as_ref();
            std::mem::transmute(filter)
        };
        inner.filter.replace(filter_ref);

        let io_ref = IoRef(inner);

        // start io tasks
        let hnd = io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone()));
        io_ref.0.handle.set(hnd);

        Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
    }
}
impl<F> Io<F> {
    #[inline]
    /// Switch this io instance to a different memory pool, migrating
    /// any live read/write buffers into the new pool.
    pub fn set_memory_pool(&self, pool: PoolRef) {
        let st = &self.0 .0;
        if let Some(mut b) = st.read_buf.take() {
            pool.move_in(&mut b);
            st.read_buf.set(Some(b));
        }
        if let Some(mut b) = st.write_buf.take() {
            pool.move_in(&mut b);
            st.write_buf.set(Some(b));
        }
        st.pool.set(pool);
    }

    #[inline]
    /// Set io disconnect timeout in secs
    pub fn set_disconnect_timeout(&self, timeout: Millis) {
        self.0 .0.disconnect_timeout.set(timeout);
    }
}
impl<F> Io<F> {
    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.0 .0.flags.get()
    }

    #[inline]
    #[allow(clippy::should_implement_trait)]
    /// Get IoRef reference
    pub fn as_ref(&self) -> &IoRef {
        &self.0
    }

    #[inline]
    /// Get instance of IoRef
    pub fn get_ref(&self) -> IoRef {
        self.0.clone()
    }

    #[inline]
    /// Check if dispatcher marked stopped
    pub fn is_dispatcher_stopped(&self) -> bool {
        self.flags().contains(Flags::DSP_STOP)
    }

    #[inline]
    /// Register dispatcher task
    pub fn register_dispatcher(&self, cx: &mut Context<'_>) {
        self.0 .0.dispatch_task.register(cx.waker());
    }

    #[inline]
    /// Reset keep-alive error
    pub fn reset_keepalive(&self) {
        self.0 .0.remove_flags(Flags::DSP_KEEPALIVE)
    }
}
impl<F: Filter> Io<F> {
    #[inline]
    /// Get reference to filter
    pub fn filter(&self) -> &F {
        if let FilterItem::Ptr(p) = self.1 {
            if let Some(r) = unsafe { p.as_ref() } {
                return r;
            }
        }
        // unreachable unless the filter was already moved out (null ptr)
        // or is type-erased (Boxed)
        panic!()
    }

    #[inline]
    /// Convert into a type-erased `IoBoxed`, keeping the current filter.
    pub fn into_boxed(mut self) -> crate::IoBoxed
    where
        F: 'static,
    {
        // get current filter
        let filter = unsafe {
            // leave a null marker so our own Drop does not free it again
            let item = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut()));
            let filter: Box<dyn Filter> = match item {
                FilterItem::Boxed(b) => b,
                FilterItem::Ptr(p) => Box::new(*Box::from_raw(p)),
            };
            // SAFETY: the 'static reference points into the box that the
            // returned Io keeps alive until its Drop resets the cell
            let filter_ref: &'static dyn Filter = {
                let filter: &dyn Filter = filter.as_ref();
                std::mem::transmute(filter)
            };
            self.0 .0.filter.replace(filter_ref);
            filter
        };
        Io(self.0.clone(), FilterItem::Boxed(filter))
    }

    #[inline]
    /// Wrap the current filter with a new one produced by `factory`.
    pub fn add_filter<T>(self, factory: T) -> T::Future
    where
        T: FilterFactory<F>,
    {
        factory.create(self)
    }

    #[inline]
    /// Map the concrete filter type `F` into `T`, re-registering the
    /// new filter with the shared state.
    pub fn map_filter<T, U, E>(mut self, map: U) -> Result<Io<T>, E>
    where
        T: Filter,
        U: FnOnce(F) -> Result<T, E>,
    {
        // replace current filter
        let filter = unsafe {
            let item = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut()));
            let filter = match item {
                // only a concrete (pointer-held) filter can be mapped
                FilterItem::Boxed(_) => panic!(),
                FilterItem::Ptr(p) => {
                    assert!(!p.is_null());
                    Box::new(map(*Box::from_raw(p))?)
                }
            };
            // SAFETY: same ownership scheme as `with_memory_pool` - the
            // box outlives the transmuted reference via the returned Io
            let filter_ref: &'static dyn Filter = {
                let filter: &dyn Filter = filter.as_ref();
                std::mem::transmute(filter)
            };
            self.0 .0.filter.replace(filter_ref);
            filter
        };

        Ok(Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter))))
    }
}
impl<F> Io<F> {
    #[inline]
    /// Read incoming io stream and decode codec item.
    pub async fn next<U>(
        &self,
        codec: &U,
    ) -> Option<Result<U::Item, Either<U::Error, io::Error>>>
    where
        U: Decoder,
    {
        poll_fn(|cx| self.poll_read_next(codec, cx)).await
    }

    #[inline]
    /// Encode item, send to a peer and wait until the data is flushed.
    pub async fn send<U>(
        &self,
        item: U::Item,
        codec: &U,
    ) -> Result<(), Either<U::Error, io::Error>>
    where
        U: Encoder,
    {
        let filter = self.filter();
        let mut buf = filter
            .get_write_buf()
            .unwrap_or_else(|| self.memory_pool().get_write_buf());
        // write task only needs a wake-up when the buffer was empty
        let is_write_sleep = buf.is_empty();

        codec.encode(item, &mut buf).map_err(Either::Left)?;
        filter.release_write_buf(buf).map_err(Either::Right)?;
        if is_write_sleep {
            self.0 .0.write_task.wake();
        }

        // wait for full flush (full == true)
        poll_fn(|cx| self.poll_write_ready(cx, true))
            .await
            .map_err(Either::Right)?;

        Ok(())
    }

    #[inline]
    /// Wake write task and instruct to write data.
    ///
    /// This is async version of .poll_write_ready() method.
    pub async fn write_ready(&self, full: bool) -> Result<(), io::Error> {
        poll_fn(|cx| self.poll_write_ready(cx, full)).await
    }

    #[inline]
    /// Shut down connection
    pub async fn shutdown(&self) -> Result<(), io::Error> {
        poll_fn(|cx| self.poll_shutdown(cx)).await
    }
}
impl<F> Io<F> {
    #[inline]
    /// Wake write task and instruct to write data.
    ///
    /// If full is true then wake up dispatcher when all data is flushed
    /// otherwise wake up when size of write buffer is lower than
    /// buffer max size.
    pub fn poll_write_ready(
        &self,
        cx: &mut Context<'_>,
        full: bool,
    ) -> Poll<io::Result<()>> {
        // check io error
        if !self.0 .0.is_io_open() {
            return Poll::Ready(Err(self.0 .0.error.take().unwrap_or_else(|| {
                io::Error::new(io::ErrorKind::Other, "disconnected")
            })));
        }

        if let Some(buf) = self.0 .0.write_buf.take() {
            let len = buf.len();
            if len != 0 {
                // buffer is not empty yet: put it back and decide whether
                // to wait for a flush
                self.0 .0.write_buf.set(Some(buf));

                if full {
                    // wait until the write task drains the whole buffer
                    self.0 .0.insert_flags(Flags::WR_WAIT);
                    self.0 .0.dispatch_task.register(cx.waker());
                    return Poll::Pending;
                } else if len >= self.0.memory_pool().write_params_high() << 1 {
                    // buffer exceeds twice the high watermark: apply
                    // write back-pressure until it shrinks
                    self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
                    self.0 .0.dispatch_task.register(cx.waker());
                    return Poll::Pending;
                } else {
                    self.0 .0.remove_flags(Flags::WR_BACKPRESSURE);
                }
            }
        }

        Poll::Ready(Ok(()))
    }

    #[inline]
    /// Wake read task and instruct to read more data
    ///
    /// Read task is awake only if back-pressure is enabled
    /// otherwise it is already awake. Buffer read status gets clean up.
    pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Option<io::Result<()>>> {
        if !self.0 .0.is_io_open() {
            // io is gone: surface the stored error once, then None
            if let Some(err) = self.0 .0.error.take() {
                Poll::Ready(Some(Err(err)))
            } else {
                Poll::Ready(None)
            }
        } else {
            self.0 .0.dispatch_task.register(cx.waker());

            let mut flags = self.0 .0.flags.get();
            let ready = flags.contains(Flags::RD_READY);
            if flags.contains(Flags::RD_BUF_FULL) {
                log::trace!("read back-pressure is disabled, wake io task");
                flags.remove(Flags::RD_READY | Flags::RD_BUF_FULL);
                self.0 .0.read_task.wake();
                self.0 .0.flags.set(flags);
                if ready {
                    Poll::Ready(Some(Ok(())))
                } else {
                    Poll::Pending
                }
            } else if ready {
                log::trace!("waking up io read task");
                flags.remove(Flags::RD_READY);
                self.0 .0.flags.set(flags);
                self.0 .0.read_task.wake();
                Poll::Ready(Some(Ok(())))
            } else {
                Poll::Pending
            }
        }
    }

    #[inline]
    #[allow(clippy::type_complexity)]
    /// Try to decode the next frame from the read buffer; on an empty
    /// buffer, arrange a wake-up via `poll_read_ready`.
    pub fn poll_read_next<U>(
        &self,
        codec: &U,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<U::Item, Either<U::Error, io::Error>>>>
    where
        U: Decoder,
    {
        match self.decode(codec) {
            Ok(Some(el)) => Poll::Ready(Some(Ok(el))),
            Ok(None) => match self.poll_read_ready(cx) {
                // new data may have arrived, but we report Pending and
                // rely on the registered waker to re-poll
                Poll::Pending | Poll::Ready(Some(Ok(()))) => Poll::Pending,
                Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(Either::Right(e)))),
                Poll::Ready(None) => Poll::Ready(None),
            },
            Err(err) => Poll::Ready(Some(Err(Either::Left(err)))),
        }
    }

    #[inline]
    /// Shut down connection
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let flags = self.flags();

        if flags.intersects(Flags::IO_ERR | Flags::IO_CLOSED) {
            Poll::Ready(Ok(()))
        } else {
            // start the filter shutdown sequence once
            if !flags.contains(Flags::IO_FILTERS) {
                self.0 .0.init_shutdown(Some(cx), self.as_ref());
            }

            if let Some(err) = self.0 .0.error.take() {
                Poll::Ready(Err(err))
            } else {
                self.0 .0.dispatch_task.register(cx.waker());
                Poll::Pending
            }
        }
    }

    #[inline]
    /// Pause read task
    pub fn pause(&self, cx: &mut Context<'_>) {
        self.0 .0.insert_flags(Flags::RD_PAUSED);
        self.0 .0.dispatch_task.register(cx.waker());
    }

    #[inline]
    /// Wake read io task if it is paused
    pub fn resume(&self) -> bool {
        let flags = self.0 .0.flags.get();
        if flags.contains(Flags::RD_PAUSED) {
            self.0 .0.remove_flags(Flags::RD_PAUSED);
            self.0 .0.read_task.wake();
            true
        } else {
            false
        }
    }

    #[inline]
    /// Wait until write task flushes data to io stream
    ///
    /// Write task must be waken up separately.
    pub fn enable_write_backpressure(&self, cx: &mut Context<'_>) {
        log::trace!("enable write back-pressure for dispatcher");
        self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
        self.0 .0.dispatch_task.register(cx.waker());
    }
}
impl<F> Drop for Io<F> {
    fn drop(&mut self) {
        // a null filter pointer means ownership was already moved out
        // (into_boxed / map_filter); there is nothing left to clean up
        if let FilterItem::Ptr(p) = self.1 {
            if p.is_null() {
                return;
            }
        }

        log::trace!(
            "io is dropped, force stopping io streams {:?}",
            self.0.flags()
        );

        self.force_close();
        self.0 .0.filter.set(NullFilter::get());

        // free the concrete filter if we still own it via raw pointer
        if let FilterItem::Ptr(p) =
            mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut()))
        {
            // SAFETY: `p` came from Box::into_raw and was checked non-null
            unsafe { Box::from_raw(p) };
        }
    }
}
impl fmt::Debug for Io {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Io")
            .field("open", &!self.is_closed())
            .finish()
    }
}

// all `IoRef` methods are available directly on `Io`
impl<F> Deref for Io<F> {
    type Target = IoRef;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
/// OnDisconnect future resolves when socket get disconnected
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {
    // index into IoState::on_disconnect; usize::MAX means the
    // connection was already disconnected when this future was created
    token: usize,
    inner: Rc<IoState>,
}
impl OnDisconnect {
    /// Create a waiter bound to the current connection state.
    pub(super) fn new(inner: Rc<IoState>) -> Self {
        let disconnected = inner.flags.get().contains(Flags::IO_ERR);
        Self::new_inner(disconnected, inner)
    }

    fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
        let token = if disconnected {
            // sentinel: resolve immediately, no waker slot needed
            usize::MAX
        } else {
            let mut slots = inner.on_disconnect.borrow_mut();
            slots.push(Some(LocalWaker::default()));
            slots.len() - 1
        };
        Self { token, inner }
    }

    #[inline]
    /// Check if connection is disconnected
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
        if self.token == usize::MAX {
            return Poll::Ready(());
        }
        // slot is cleared (None) once the disconnect notification fired
        match self.inner.on_disconnect.borrow()[self.token].as_ref() {
            Some(waker) => {
                waker.register(cx.waker());
                Poll::Pending
            }
            None => Poll::Ready(()),
        }
    }
}
impl Clone for OnDisconnect {
    fn clone(&self) -> Self {
        // a clone of an already-resolved waiter stays resolved; an
        // active one registers its own fresh waker slot
        OnDisconnect::new_inner(self.token == usize::MAX, self.inner.clone())
    }
}
impl Future for OnDisconnect {
    type Output = ();

    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.poll_ready(cx)
    }
}

impl Drop for OnDisconnect {
    fn drop(&mut self) {
        // release our waker slot; the Vec slot is reused as None
        if self.token != usize::MAX {
            self.inner.on_disconnect.borrow_mut()[self.token].take();
        }
    }
}

558
ntex-io/src/ioref.rs Normal file
View file

@ -0,0 +1,558 @@
use std::{any, fmt, io};
use ntex_bytes::{BytesMut, PoolRef};
use ntex_codec::{Decoder, Encoder};
use super::io::{Flags, IoRef, OnDisconnect};
use super::{types, Filter};
impl IoRef {
    #[inline]
    #[doc(hidden)]
    /// Get current state flags
    pub fn flags(&self) -> Flags {
        self.0.flags.get()
    }

    #[inline]
    /// Replace state flags wholesale
    pub(crate) fn set_flags(&self, flags: Flags) {
        self.0.flags.set(flags)
    }

    #[inline]
    /// Get the currently installed filter
    pub(crate) fn filter(&self) -> &dyn Filter {
        self.0.filter.get()
    }

    #[inline]
    /// Get memory pool
    pub fn memory_pool(&self) -> PoolRef {
        self.0.pool.get()
    }

    #[inline]
    /// Check if io is still active
    pub fn is_io_open(&self) -> bool {
        self.0.is_io_open()
    }
    #[inline]
    /// Check if keep-alive timeout occurred
    pub fn is_keepalive(&self) -> bool {
        self.0.flags.get().contains(Flags::DSP_KEEPALIVE)
    }

    #[inline]
    /// Check if io stream is closed or stopping
    pub fn is_closed(&self) -> bool {
        self.0.flags.get().intersects(
            Flags::IO_ERR
                | Flags::IO_SHUTDOWN
                | Flags::IO_CLOSED
                | Flags::IO_FILTERS
                | Flags::DSP_STOP,
        )
    }

    #[inline]
    /// Take io error if any occurred
    pub fn take_error(&self) -> Option<io::Error> {
        self.0.error.take()
    }
    #[inline]
    /// Wake dispatcher task
    pub fn wake_dispatcher(&self) {
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Mark dispatcher as stopped (does not wake it)
    pub fn stop_dispatcher(&self) {
        self.0.insert_flags(Flags::DSP_STOP);
    }

    #[inline]
    /// Gracefully close connection
    ///
    /// First stop dispatcher, then dispatcher stops io tasks
    pub fn close(&self) {
        self.0.insert_flags(Flags::DSP_STOP);
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Force close connection
    ///
    /// Dispatcher does not wait for uncompleted responses, but flushes io buffers.
    pub fn force_close(&self) {
        log::trace!("force close framed object");
        self.0.insert_flags(Flags::DSP_STOP | Flags::IO_SHUTDOWN);
        // wake all three tasks so they observe the shutdown flags
        self.0.read_task.wake();
        self.0.write_task.wake();
        self.0.dispatch_task.wake();
    }

    #[inline]
    /// Notify when io stream get disconnected
    pub fn on_disconnect(&self) -> OnDisconnect {
        OnDisconnect::new(self.0.clone())
    }
    #[inline]
    /// Query filter-specific data by type
    pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
        if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
            types::QueryItem::new(item)
        } else {
            types::QueryItem::empty()
        }
    }

    #[inline]
    /// Check if write task is ready (no write back-pressure)
    pub fn is_write_ready(&self) -> bool {
        !self.0.flags.get().contains(Flags::WR_BACKPRESSURE)
    }

    #[inline]
    /// Check if read buffer has new data
    pub fn is_read_ready(&self) -> bool {
        self.0.flags.get().contains(Flags::RD_READY)
    }
#[inline]
/// Check if write buffer is full
pub fn is_write_buf_full(&self) -> bool {
if let Some(buf) = self.0.read_buf.take() {
let hw = self.memory_pool().write_params_high();
let result = buf.len() >= hw;
self.0.write_buf.set(Some(buf));
result
} else {
false
}
}
    #[inline]
    /// Check if read buffer is full
    pub fn is_read_buf_full(&self) -> bool {
        // peek without holding a borrow: take from the Cell, inspect,
        // then put the buffer straight back
        if let Some(buf) = self.0.read_buf.take() {
            let result = buf.len() >= self.memory_pool().read_params_high();
            self.0.read_buf.set(Some(buf));
            result
        } else {
            false
        }
    }

    #[inline]
    /// Wait until write task flushes data to io stream
    ///
    /// Write task must be waken up separately.
    pub fn enable_write_backpressure(&self) {
        log::trace!("enable write back-pressure");
        self.0.insert_flags(Flags::WR_BACKPRESSURE);
    }
#[inline]
/// Get mut access to write buffer
///
/// Runs the closure over the current write buffer (allocating one from the
/// memory pool if the filter has none) and hands the buffer back to the
/// filter chain afterwards.
pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error>
where
    F: FnOnce(&mut BytesMut) -> R,
{
    let filter = self.0.filter.get();
    let mut wbuf = match filter.get_write_buf() {
        Some(buf) => buf,
        None => self.memory_pool().get_write_buf(),
    };
    // an empty buffer means the write task is parked; wake it so data
    // produced by the closure gets flushed
    if wbuf.is_empty() {
        self.0.write_task.wake();
    }
    let res = f(&mut wbuf);
    filter.release_write_buf(wbuf)?;
    Ok(res)
}
#[inline]
/// Get mut access to read buffer
///
/// Runs the closure over the current read buffer, allocating one from the
/// memory pool when none is stored.
pub fn with_read_buf<F, R>(&self, f: F) -> R
where
    F: FnOnce(&mut BytesMut) -> R,
{
    let mut rbuf = match self.0.read_buf.take() {
        Some(buf) => buf,
        None => self.memory_pool().get_read_buf(),
    };
    let result = f(&mut rbuf);
    // store a non-empty buffer back; release an empty one to the pool
    if rbuf.is_empty() {
        self.memory_pool().release_read_buf(rbuf);
    } else {
        self.0.read_buf.set(Some(rbuf));
    }
    result
}
#[inline]
/// Encode and write item to a buffer and wake up write task
///
/// Returns write buffer state: false is returned if write buffer is full
/// (back-pressure should be applied by the caller).
pub fn encode<U>(
&self,
item: U::Item,
codec: &U,
) -> Result<bool, <U as Encoder>::Error>
where
U: Encoder,
{
let flags = self.0.flags.get();
// do not encode into a failed or shutting-down stream; report Ok(true)
// so callers do not apply back-pressure on a dead connection
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf());
// an empty buffer means the write task is currently parked
let is_write_sleep = buf.is_empty();
let (hw, lw) = self.memory_pool().write_params().unpack();
// make sure we've got room
// (remaining < lw <= hw, so `hw - remaining` cannot underflow)
let remaining = buf.capacity() - buf.len();
if remaining < lw {
buf.reserve(hw - remaining);
}
// encode item and wake write task
let result = codec.encode(item, &mut buf).map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
// false signals the buffer reached the high watermark
buf.len() < hw
});
// hand the buffer back to the filter chain; a filter io::Error is
// stored on the io state (the encoder's error type differs from it)
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
}
result
} else {
Ok(true)
}
}
#[inline]
/// Attempts to decode a frame from the read buffer
///
/// Read buffer ready state gets cleanup if decoder cannot
/// decode any frame.
pub fn decode<U>(
    &self,
    codec: &U,
) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
where
    U: Decoder,
{
    // no stored buffer means nothing has been read yet
    match self.0.read_buf.take() {
        None => Ok(None),
        Some(mut buf) => {
            let decoded = codec.decode(&mut buf);
            // put the (possibly consumed) buffer back for the next call
            self.0.read_buf.set(Some(buf));
            decoded
        }
    }
}
#[inline]
/// Write bytes to a buffer and wake up write task
///
/// Returns write buffer state, false is returned if write buffer if full.
pub fn write(&self, src: &[u8]) -> Result<bool, io::Error> {
    // a failed or shutting-down stream accepts nothing; report Ok(true)
    // so callers do not apply back-pressure on a dead connection
    if self
        .0
        .flags
        .get()
        .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN)
    {
        return Ok(true);
    }

    let filter = self.0.filter.get();
    let mut buf = filter
        .get_write_buf()
        .unwrap_or_else(|| self.memory_pool().get_write_buf());
    // an empty buffer means the write task is currently parked
    let wake_writer = buf.is_empty();

    // append data and check against the high watermark
    buf.extend_from_slice(src);
    let below_hw = buf.len() < self.memory_pool().write_params_high();
    if wake_writer {
        self.0.write_task.wake();
    }
    // hand the buffer back to the filter chain; store any filter error
    if let Err(err) = filter.release_write_buf(buf) {
        self.0.set_error(Some(err));
    }
    Ok(below_hw)
}
}
impl fmt::Debug for IoRef {
    /// Render the io handle as `IoRef { open: <bool> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let open = !self.is_closed();
        f.debug_struct("IoRef").field("open", &open).finish()
    }
}
#[cfg(test)]
mod tests {
use std::{cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
use ntex_bytes::Bytes;
use ntex_codec::BytesCodec;
use ntex_util::future::{lazy, poll_fn, Ready};
use ntex_util::time::{sleep, Millis};
use super::*;
use crate::testing::IoTest;
use crate::{Filter, FilterFactory, Io, WriteReadiness};
// BIN/TEXT are byte/str views of the same request frame.
const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
// Exercises basic Io read/write helpers: next/poll_read_next/send and
// the error/flag propagation on read/write failures and force_close.
#[ntex::test]
async fn utils() {
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write(TEXT);
let state = Io::new(server);
assert!(!state.is_read_buf_full());
assert!(!state.is_write_buf_full());
// a full frame is already buffered, so next() resolves immediately
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN));
// nothing left to read -> poll must be pending
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
assert!(res.is_pending());
client.write(TEXT);
sleep(Millis(50)).await;
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
if let Poll::Ready(msg) = res {
assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN));
}
// a read error must surface through next() and set error/stop flags
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
let msg = state.next(&BytesCodec).await;
assert!(msg.unwrap().is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::DSP_STOP));
// same error propagation through poll_read_next on a fresh pair
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
let state = Io::new(server);
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
if let Poll::Ready(msg) = res {
assert!(msg.unwrap().is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::DSP_STOP));
}
// send() delivers encoded bytes to the peer ...
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
let state = Io::new(server);
state
.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
// ... and a write error fails send() and sets error/stop flags
client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
assert!(res.is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::DSP_STOP));
// force_close sets both dispatcher-stop and io-shutdown flags
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
let state = Io::new(server);
state.force_close();
assert!(state.flags().contains(Flags::DSP_STOP));
assert!(state.flags().contains(Flags::IO_SHUTDOWN));
}
// on_disconnect() futures (and their clones) stay pending until the peer
// closes or errors, then all resolve; a waiter created after disconnect
// resolves immediately.
#[ntex::test]
async fn on_disconnect() {
let (client, server) = IoTest::create();
let state = Io::new(server);
let mut waiter = state.on_disconnect();
assert_eq!(
lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
Poll::Pending
);
// cloned waiters track the same disconnect event
let mut waiter2 = waiter.clone();
assert_eq!(
lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
Poll::Pending
);
client.close().await;
assert_eq!(waiter.await, ());
assert_eq!(waiter2.await, ());
// registering after disconnect resolves immediately
let mut waiter = state.on_disconnect();
assert_eq!(
lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
Poll::Ready(())
);
// a read error also counts as disconnect
let (client, server) = IoTest::create();
let state = Io::new(server);
let mut waiter = state.on_disconnect();
assert_eq!(
lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
Poll::Pending
);
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
assert_eq!(waiter.await, ());
}
// Pass-through filter that counts bytes flowing in (release_read_buf) and
// out (release_write_buf / get_write_buf) of the wrapped filter.
struct Counter<F> {
inner: F,
// total bytes received from the peer
in_bytes: Rc<Cell<usize>>,
// total bytes queued for the peer
out_bytes: Rc<Cell<usize>>,
}
impl<F: Filter> Filter for Counter<F> {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn query(&self, _: std::any::TypeId) -> Option<Box<dyn std::any::Any>> {
None
}
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.inner.poll_read_ready(cx)
}
fn closed(&self, err: Option<io::Error>) {
self.inner.closed(err)
}
fn get_read_buf(&self) -> Option<BytesMut> {
self.inner.get_read_buf()
}
fn release_read_buf(
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
// count newly received bytes, then delegate
self.in_bytes.set(self.in_bytes.get() + new_bytes);
self.inner.release_read_buf(buf, new_bytes)
}
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.poll_write_ready(cx)
}
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.get_write_buf() {
// buffer checked out: subtract now so release adds only the delta
self.out_bytes.set(self.out_bytes.get() - buf.len());
Some(buf)
} else {
None
}
}
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
self.out_bytes.set(self.out_bytes.get() + buf.len());
self.inner.release_write_buf(buf)
}
}
// FilterFactory that wraps an io's current filter with a Counter sharing
// the provided (in_bytes, out_bytes) cells.
struct CounterFactory(Rc<Cell<usize>>, Rc<Cell<usize>>);
impl<F: Filter> FilterFactory<F> for CounterFactory {
type Filter = Counter<F>;
type Error = ();
type Future = Ready<Io<Counter<F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future {
let in_bytes = self.0.clone();
let out_bytes = self.1.clone();
Ready::Ok(
io.map_filter(|inner| {
Ok::<_, ()>(Counter {
inner,
in_bytes,
out_bytes,
})
})
.unwrap(),
)
}
}
// A single Counter filter sees each byte exactly once in each direction.
#[ntex::test]
async fn filter() {
let in_bytes = Rc::new(Cell::new(0));
let out_bytes = Rc::new(Cell::new(0));
let factory = CounterFactory(in_bytes.clone(), out_bytes.clone());
let (client, server) = IoTest::create();
let state = Io::new(server).add_filter(factory).await.unwrap();
client.remote_buffer_cap(1024);
client.write(TEXT);
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN));
state
.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
assert_eq!(in_bytes.get(), BIN.len());
assert_eq!(out_bytes.get(), 4);
}
// Two stacked Counter filters (boxed io) see each byte twice, and dropping
// the boxed io releases all filter-held clones of the shared counters.
#[ntex::test]
async fn boxed_filter() {
let in_bytes = Rc::new(Cell::new(0));
let out_bytes = Rc::new(Cell::new(0));
let (client, server) = IoTest::create();
let state = Io::new(server)
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.await
.unwrap()
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.await
.unwrap();
let state = state.into_boxed();
client.remote_buffer_cap(1024);
client.write(TEXT);
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN));
state
.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
// bytes counted once per filter layer
assert_eq!(in_bytes.get(), BIN.len() * 2);
assert_eq!(out_bytes.get(), 8);
// refs
assert_eq!(Rc::strong_count(&in_bytes), 3);
drop(state);
assert_eq!(Rc::strong_count(&in_bytes), 1);
}
}

View file

@ -1,11 +1,15 @@
use std::{any::Any, any::TypeId, fmt, future::Future, io, task::Context, task::Poll};
use std::{
any::Any, any::TypeId, fmt, future::Future, io::Error as IoError, task::Context,
task::Poll,
};
pub mod testing;
pub mod types;
mod dispatcher;
mod filter;
mod state;
mod io;
mod ioref;
mod tasks;
mod time;
mod utils;
@ -21,7 +25,7 @@ use ntex_util::time::Millis;
pub use self::dispatcher::Dispatcher;
pub use self::filter::DefaultFilter;
pub use self::state::{Io, IoRef, OnDisconnect, ReadRef, WriteRef};
pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::tasks::{ReadContext, WriteContext};
pub use self::time::Timer;
@ -37,9 +41,9 @@ pub enum WriteReadiness {
}
pub trait Filter: 'static {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), IoError>>;
fn closed(&self, err: Option<io::Error>);
fn closed(&self, err: Option<IoError>);
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
@ -52,9 +56,9 @@ pub trait Filter: 'static {
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), IoError>;
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
fn release_write_buf(&self, buf: BytesMut) -> Result<(), IoError>;
}
pub trait FilterFactory<F: Filter>: Sized {
@ -88,7 +92,7 @@ pub enum DispatchItem<U: Encoder + Decoder> {
/// Encoder parse error
EncoderError(<U as Encoder>::Error),
/// Socket is disconnected
Disconnect(Option<io::Error>),
Disconnect(Option<IoError>),
}
impl<U> fmt::Debug for DispatchItem<U>
@ -134,6 +138,7 @@ pub mod rt {
mod tests {
use super::*;
use ntex_codec::BytesCodec;
use std::io;
#[test]
fn test_fmt() {

File diff suppressed because it is too large Load diff

View file

@ -2,7 +2,7 @@ use std::{io, task::Context, task::Poll};
use ntex_bytes::{BytesMut, PoolRef};
use super::{state::Flags, IoRef, WriteReadiness};
use super::{io::Flags, IoRef, WriteReadiness};
pub struct ReadContext(pub(super) IoRef);
@ -45,6 +45,7 @@ impl ReadContext {
flags.insert(Flags::RD_READY);
self.0.set_flags(flags);
self.0 .0.dispatch_task.wake();
log::trace!("new {} bytes available, wakeup dispatcher", new_bytes);
}
self.0.filter().release_read_buf(buf, new_bytes)?;

View file

@ -409,7 +409,10 @@ impl Future for ReadTask {
match io.poll_read_buf(cx, &mut buf) {
Poll::Pending => {
log::trace!("no more data in io stream");
log::trace!(
"no more data in io stream, read: {:?}",
new_bytes
);
break;
}
Poll::Ready(Ok(n)) => {

View file

@ -4,15 +4,14 @@ use std::{
use ntex_util::time::{now, sleep, Millis};
use crate::rt::spawn;
use crate::state::{IoRef, IoStateInner};
use crate::{io::IoState, rt::spawn, IoRef};
pub struct Timer(Rc<RefCell<Inner>>);
struct Inner {
running: bool,
resolution: Millis,
notifications: BTreeMap<Instant, HashSet<Rc<IoStateInner>, fxhash::FxBuildHasher>>,
notifications: BTreeMap<Instant, HashSet<Rc<IoState>, fxhash::FxBuildHasher>>,
}
impl Inner {

View file

@ -355,21 +355,17 @@ impl<F: Filter> AsyncRead for Io<F> {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let read = self.read();
let len = read.with_buf(|src| {
let len = self.with_read_buf(|src| {
let len = cmp::min(src.len(), buf.capacity());
buf.put_slice(&src.split_to(len));
len
});
if len == 0 {
match read.poll_read_ready(cx) {
Ok(()) => Poll::Pending,
Err(Some(e)) => Poll::Ready(Err(e)),
Err(None) => Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"disconnected",
))),
match self.poll_read_ready(cx) {
Poll::Pending | Poll::Ready(Some(Ok(()))) => Poll::Pending,
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(e)),
Poll::Ready(None) => Poll::Ready(Ok(())),
}
} else {
Poll::Ready(Ok(()))
@ -383,18 +379,18 @@ impl<F: Filter> AsyncWrite for Io<F> {
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.write().write(buf).map(|_| buf.len()))
Poll::Ready(self.write(buf).map(|_| buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.write().poll_write_ready(cx, false)
self.poll_write_ready(cx, false)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.0.poll_shutdown(cx)
Io::poll_shutdown(&*self, cx)
}
}
@ -404,21 +400,17 @@ impl AsyncRead for IoBoxed {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let read = self.read();
let len = read.with_buf(|src| {
let len = self.with_read_buf(|src| {
let len = cmp::min(src.len(), buf.capacity());
buf.put_slice(&src.split_to(len));
len
});
if len == 0 {
match read.poll_read_ready(cx) {
Ok(()) => Poll::Pending,
Err(Some(e)) => Poll::Ready(Err(e)),
Err(None) => Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"disconnected",
))),
match self.poll_read_ready(cx) {
Poll::Pending | Poll::Ready(Some(Ok(()))) => Poll::Pending,
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(e)),
Poll::Ready(None) => Poll::Ready(Ok(())),
}
} else {
Poll::Ready(Ok(()))
@ -432,18 +424,18 @@ impl AsyncWrite for IoBoxed {
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.write().write(buf).map(|_| buf.len()))
Poll::Ready(self.write(buf).map(|_| buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.write().poll_write_ready(cx, false)
self.poll_write_ready(cx, false)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.0.poll_shutdown(cx)
Self::poll_shutdown(&*self, cx)
}
}