mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-04 13:27:39 +03:00
cleanup ntex-io api
This commit is contained in:
parent
a5d734fe47
commit
ed57a964b6
30 changed files with 1670 additions and 1726 deletions
|
@ -1,5 +1,11 @@
|
||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
|
## [0.1.0-b.2] - 2021-12-20
|
||||||
|
|
||||||
|
* Removed `WriteRef` and `ReadRef`
|
||||||
|
|
||||||
|
* Better Io/IoRef api separation
|
||||||
|
|
||||||
## [0.1.0-b.1] - 2021-12-19
|
## [0.1.0-b.1] - 2021-12-19
|
||||||
|
|
||||||
* Remove ReadFilter/WriteFilter traits.
|
* Remove ReadFilter/WriteFilter traits.
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "ntex-io"
|
name = "ntex-io"
|
||||||
version = "0.1.0-b.1"
|
version = "0.1.0-b.2"
|
||||||
authors = ["ntex contributors <team@ntex.rs>"]
|
authors = ["ntex contributors <team@ntex.rs>"]
|
||||||
description = "Utilities for encoding and decoding frames"
|
description = "Utilities for encoding and decoding frames"
|
||||||
keywords = ["network", "framework", "async", "futures"]
|
keywords = ["network", "framework", "async", "futures"]
|
||||||
|
|
|
@ -7,7 +7,7 @@ use ntex_service::{IntoService, Service};
|
||||||
use ntex_util::future::Either;
|
use ntex_util::future::Either;
|
||||||
use ntex_util::time::{now, Seconds};
|
use ntex_util::time::{now, Seconds};
|
||||||
|
|
||||||
use super::{rt::spawn, DispatchItem, IoBoxed, ReadRef, Timer, WriteRef};
|
use super::{rt::spawn, DispatchItem, IoBoxed, IoRef, Timer};
|
||||||
|
|
||||||
type Response<U> = <U as Encoder>::Item;
|
type Response<U> = <U as Encoder>::Item;
|
||||||
|
|
||||||
|
@ -33,8 +33,8 @@ where
|
||||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||||
U: Encoder + Decoder,
|
U: Encoder + Decoder,
|
||||||
{
|
{
|
||||||
|
io: IoBoxed,
|
||||||
st: Cell<DispatcherState>,
|
st: Cell<DispatcherState>,
|
||||||
state: IoBoxed,
|
|
||||||
timer: Timer,
|
timer: Timer,
|
||||||
ka_timeout: Seconds,
|
ka_timeout: Seconds,
|
||||||
ka_updated: Cell<time::Instant>,
|
ka_updated: Cell<time::Instant>,
|
||||||
|
@ -90,7 +90,7 @@ where
|
||||||
{
|
{
|
||||||
/// Construct new `Dispatcher` instance.
|
/// Construct new `Dispatcher` instance.
|
||||||
pub fn new<F: IntoService<S>>(
|
pub fn new<F: IntoService<S>>(
|
||||||
state: IoBoxed,
|
io: IoBoxed,
|
||||||
codec: U,
|
codec: U,
|
||||||
service: F,
|
service: F,
|
||||||
timer: Timer,
|
timer: Timer,
|
||||||
|
@ -100,13 +100,13 @@ where
|
||||||
|
|
||||||
// register keepalive timer
|
// register keepalive timer
|
||||||
let expire = updated + time::Duration::from(ka_timeout);
|
let expire = updated + time::Duration::from(ka_timeout);
|
||||||
timer.register(expire, expire, &state);
|
timer.register(expire, expire, &io);
|
||||||
|
|
||||||
Dispatcher {
|
Dispatcher {
|
||||||
service: service.into_service(),
|
service: service.into_service(),
|
||||||
fut: None,
|
fut: None,
|
||||||
inner: DispatcherInner {
|
inner: DispatcherInner {
|
||||||
pool: state.memory_pool().pool(),
|
pool: io.memory_pool().pool(),
|
||||||
ka_updated: Cell::new(updated),
|
ka_updated: Cell::new(updated),
|
||||||
error: Cell::new(None),
|
error: Cell::new(None),
|
||||||
ready_err: Cell::new(false),
|
ready_err: Cell::new(false),
|
||||||
|
@ -116,7 +116,7 @@ where
|
||||||
error: Cell::new(None),
|
error: Cell::new(None),
|
||||||
inflight: Cell::new(0),
|
inflight: Cell::new(0),
|
||||||
}),
|
}),
|
||||||
state,
|
io,
|
||||||
timer,
|
timer,
|
||||||
ka_timeout,
|
ka_timeout,
|
||||||
},
|
},
|
||||||
|
@ -132,10 +132,10 @@ where
|
||||||
// register keepalive timer
|
// register keepalive timer
|
||||||
let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka());
|
let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka());
|
||||||
if timeout.is_zero() {
|
if timeout.is_zero() {
|
||||||
self.inner.timer.unregister(prev, &self.inner.state);
|
self.inner.timer.unregister(prev, &self.inner.io);
|
||||||
} else {
|
} else {
|
||||||
let expire = self.inner.ka_updated.get() + time::Duration::from(timeout);
|
let expire = self.inner.ka_updated.get() + time::Duration::from(timeout);
|
||||||
self.inner.timer.register(expire, prev, &self.inner.state);
|
self.inner.timer.register(expire, prev, &self.inner.io);
|
||||||
}
|
}
|
||||||
self.inner.ka_timeout = timeout;
|
self.inner.ka_timeout = timeout;
|
||||||
|
|
||||||
|
@ -151,7 +151,7 @@ where
|
||||||
///
|
///
|
||||||
/// By default disconnect timeout is set to 1 seconds.
|
/// By default disconnect timeout is set to 1 seconds.
|
||||||
pub fn disconnect_timeout(self, val: Seconds) -> Self {
|
pub fn disconnect_timeout(self, val: Seconds) -> Self {
|
||||||
self.inner.state.set_disconnect_timeout(val.into());
|
self.inner.io.set_disconnect_timeout(val.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -161,18 +161,18 @@ where
|
||||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
|
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
|
||||||
U: Encoder + Decoder + 'static,
|
U: Encoder + Decoder + 'static,
|
||||||
{
|
{
|
||||||
fn handle_result(&self, item: Result<S::Response, S::Error>, write: WriteRef<'_>) {
|
fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoRef) {
|
||||||
self.inflight.set(self.inflight.get() - 1);
|
self.inflight.set(self.inflight.get() - 1);
|
||||||
match item {
|
match item {
|
||||||
Ok(Some(val)) => match write.encode(val, &self.codec) {
|
Ok(Some(val)) => match io.encode(val, &self.codec) {
|
||||||
Ok(true) => (),
|
Ok(true) => (),
|
||||||
Ok(false) => write.enable_backpressure(None),
|
Ok(false) => io.enable_write_backpressure(),
|
||||||
Err(err) => self.error.set(Some(DispatcherError::Encoder(err))),
|
Err(err) => self.error.set(Some(DispatcherError::Encoder(err))),
|
||||||
},
|
},
|
||||||
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
|
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
|
||||||
Ok(None) => return,
|
Ok(None) => return,
|
||||||
}
|
}
|
||||||
write.wake_dispatcher();
|
io.wake_dispatcher();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,9 +186,8 @@ where
|
||||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
let mut this = self.as_mut().project();
|
let mut this = self.as_mut().project();
|
||||||
let slf = &this.inner;
|
let slf = &this.inner;
|
||||||
let state = &slf.state;
|
let io = &slf.io;
|
||||||
let read = state.read();
|
let ioref = io.as_ref();
|
||||||
let write = state.write();
|
|
||||||
|
|
||||||
// handle service response future
|
// handle service response future
|
||||||
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
|
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
|
||||||
|
@ -197,79 +196,58 @@ where
|
||||||
Poll::Ready(item) => {
|
Poll::Ready(item) => {
|
||||||
this.fut.set(None);
|
this.fut.set(None);
|
||||||
slf.shared.inflight.set(slf.shared.inflight.get() - 1);
|
slf.shared.inflight.set(slf.shared.inflight.get() - 1);
|
||||||
slf.handle_result(item, write);
|
slf.handle_result(item, ioref);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle memory pool pressure
|
// handle memory pool pressure
|
||||||
if slf.pool.poll_ready(cx).is_pending() {
|
if slf.pool.poll_ready(cx).is_pending() {
|
||||||
read.pause(cx);
|
io.pause(cx);
|
||||||
return Poll::Pending;
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match slf.st.get() {
|
match slf.st.get() {
|
||||||
DispatcherState::Processing => {
|
DispatcherState::Processing => {
|
||||||
let result = match slf.poll_service(this.service, cx, read) {
|
let result = if let Poll::Ready(result) =
|
||||||
Poll::Pending => {
|
slf.poll_service(this.service, cx, io)
|
||||||
if let Err(err) = read.poll_read_ready(cx) {
|
{
|
||||||
log::error!(
|
result
|
||||||
"io error while service is in pending state: {:?}",
|
} else {
|
||||||
err
|
return Poll::Pending;
|
||||||
);
|
|
||||||
return Poll::Ready(Ok(()));
|
|
||||||
}
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
Poll::Ready(result) => result,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let item = match result {
|
let item = match result {
|
||||||
PollService::Ready => {
|
PollService::Ready => {
|
||||||
if !write.is_ready() {
|
if !io.is_write_ready() {
|
||||||
// instruct write task to notify dispatcher when data is flushed
|
// instruct write task to notify dispatcher when data is flushed
|
||||||
write.enable_backpressure(Some(cx));
|
io.enable_write_backpressure(cx);
|
||||||
slf.st.set(DispatcherState::Backpressure);
|
slf.st.set(DispatcherState::Backpressure);
|
||||||
DispatchItem::WBackPressureEnabled
|
DispatchItem::WBackPressureEnabled
|
||||||
} else if read.is_ready() {
|
} else {
|
||||||
// decode incoming bytes if buffer is ready
|
// decode incoming bytes if buffer is ready
|
||||||
match read.decode(&slf.shared.codec) {
|
match io.poll_read_next(&slf.shared.codec, cx) {
|
||||||
Ok(Some(el)) => {
|
Poll::Ready(Some(Ok(el))) => {
|
||||||
slf.update_keepalive();
|
slf.update_keepalive();
|
||||||
DispatchItem::Item(el)
|
DispatchItem::Item(el)
|
||||||
}
|
}
|
||||||
Ok(None) => {
|
Poll::Ready(Some(Err(Either::Left(err)))) => {
|
||||||
log::trace!("not enough data to decode next frame, register dispatch task");
|
|
||||||
// service is ready, wake io read task
|
|
||||||
match read.poll_read_ready(cx) {
|
|
||||||
Ok(()) => {
|
|
||||||
read.resume();
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
Err(None) => DispatchItem::Disconnect(None),
|
|
||||||
Err(Some(err)) => {
|
|
||||||
DispatchItem::Disconnect(Some(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
slf.st.set(DispatcherState::Stop);
|
slf.st.set(DispatcherState::Stop);
|
||||||
slf.unregister_keepalive();
|
slf.unregister_keepalive();
|
||||||
DispatchItem::DecoderError(err)
|
DispatchItem::DecoderError(err)
|
||||||
}
|
}
|
||||||
}
|
Poll::Ready(Some(Err(Either::Right(err)))) => {
|
||||||
} else {
|
slf.st.set(DispatcherState::Stop);
|
||||||
// no new events
|
slf.unregister_keepalive();
|
||||||
match read.poll_read_ready(cx) {
|
|
||||||
Ok(()) => {
|
|
||||||
read.resume();
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
Err(None) => DispatchItem::Disconnect(None),
|
|
||||||
Err(Some(err)) => {
|
|
||||||
DispatchItem::Disconnect(Some(err))
|
DispatchItem::Disconnect(Some(err))
|
||||||
}
|
}
|
||||||
|
Poll::Ready(None) => DispatchItem::Disconnect(None),
|
||||||
|
Poll::Pending => {
|
||||||
|
log::trace!("not enough data to decode next frame, register dispatch task");
|
||||||
|
io.resume();
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -284,7 +262,7 @@ where
|
||||||
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
||||||
Poll::Ready(res) => {
|
Poll::Ready(res) => {
|
||||||
this.fut.set(None);
|
this.fut.set(None);
|
||||||
slf.handle_result(res, write);
|
slf.handle_result(res, ioref);
|
||||||
}
|
}
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
||||||
|
@ -296,13 +274,13 @@ where
|
||||||
}
|
}
|
||||||
// handle write back-pressure
|
// handle write back-pressure
|
||||||
DispatcherState::Backpressure => {
|
DispatcherState::Backpressure => {
|
||||||
let result = match slf.poll_service(this.service, cx, read) {
|
let result = match slf.poll_service(this.service, cx, io) {
|
||||||
Poll::Ready(result) => result,
|
Poll::Ready(result) => result,
|
||||||
Poll::Pending => return Poll::Pending,
|
Poll::Pending => return Poll::Pending,
|
||||||
};
|
};
|
||||||
let item = match result {
|
let item = match result {
|
||||||
PollService::Ready => {
|
PollService::Ready => {
|
||||||
if write.is_ready() {
|
if io.is_write_ready() {
|
||||||
slf.st.set(DispatcherState::Processing);
|
slf.st.set(DispatcherState::Processing);
|
||||||
DispatchItem::WBackPressureDisabled
|
DispatchItem::WBackPressureDisabled
|
||||||
} else {
|
} else {
|
||||||
|
@ -320,7 +298,7 @@ where
|
||||||
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
||||||
Poll::Ready(res) => {
|
Poll::Ready(res) => {
|
||||||
this.fut.set(None);
|
this.fut.set(None);
|
||||||
slf.handle_result(res, write);
|
slf.handle_result(res, ioref);
|
||||||
}
|
}
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
||||||
|
@ -338,11 +316,14 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
if slf.shared.inflight.get() == 0 {
|
if slf.shared.inflight.get() == 0 {
|
||||||
slf.st.set(DispatcherState::Shutdown);
|
if io.poll_shutdown(cx).is_ready() {
|
||||||
|
slf.st.set(DispatcherState::Shutdown);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
state.register_dispatcher(cx);
|
slf.io.register_dispatcher(cx);
|
||||||
return Poll::Pending;
|
|
||||||
}
|
}
|
||||||
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
// shutdown service
|
// shutdown service
|
||||||
DispatcherState::Shutdown => {
|
DispatcherState::Shutdown => {
|
||||||
|
@ -375,23 +356,23 @@ where
|
||||||
fn spawn_service_call(&self, fut: S::Future) {
|
fn spawn_service_call(&self, fut: S::Future) {
|
||||||
self.shared.inflight.set(self.shared.inflight.get() + 1);
|
self.shared.inflight.set(self.shared.inflight.get() + 1);
|
||||||
|
|
||||||
let st = self.state.get_ref();
|
let st = self.io.get_ref();
|
||||||
let shared = self.shared.clone();
|
let shared = self.shared.clone();
|
||||||
spawn(async move {
|
spawn(async move {
|
||||||
let item = fut.await;
|
let item = fut.await;
|
||||||
shared.handle_result(item, st.write());
|
shared.handle_result(item, &st);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_result(
|
fn handle_result(
|
||||||
&self,
|
&self,
|
||||||
item: Result<Option<<U as Encoder>::Item>, S::Error>,
|
item: Result<Option<<U as Encoder>::Item>, S::Error>,
|
||||||
write: WriteRef<'_>,
|
io: &IoRef,
|
||||||
) {
|
) {
|
||||||
match item {
|
match item {
|
||||||
Ok(Some(item)) => match write.encode(item, &self.shared.codec) {
|
Ok(Some(item)) => match io.encode(item, &self.shared.codec) {
|
||||||
Ok(true) => (),
|
Ok(true) => (),
|
||||||
Ok(false) => write.enable_backpressure(None),
|
Ok(false) => io.enable_write_backpressure(),
|
||||||
Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))),
|
Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))),
|
||||||
},
|
},
|
||||||
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
|
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
|
||||||
|
@ -403,7 +384,7 @@ where
|
||||||
&self,
|
&self,
|
||||||
srv: &S,
|
srv: &S,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
read: ReadRef<'_>,
|
io: &IoBoxed,
|
||||||
) -> Poll<PollService<U>> {
|
) -> Poll<PollService<U>> {
|
||||||
match srv.poll_ready(cx) {
|
match srv.poll_ready(cx) {
|
||||||
Poll::Ready(Ok(_)) => {
|
Poll::Ready(Ok(_)) => {
|
||||||
|
@ -428,19 +409,19 @@ where
|
||||||
PollService::ServiceError
|
PollService::ServiceError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if self.state.is_dispatcher_stopped() {
|
} else if self.io.is_dispatcher_stopped() {
|
||||||
log::trace!("dispatcher is instructed to stop");
|
log::trace!("dispatcher is instructed to stop");
|
||||||
|
|
||||||
self.unregister_keepalive();
|
self.unregister_keepalive();
|
||||||
|
|
||||||
// process unhandled data
|
// process unhandled data
|
||||||
if let Ok(Some(el)) = read.decode(&self.shared.codec) {
|
if let Ok(Some(el)) = io.decode(&self.shared.codec) {
|
||||||
PollService::Item(DispatchItem::Item(el))
|
PollService::Item(DispatchItem::Item(el))
|
||||||
} else {
|
} else {
|
||||||
self.st.set(DispatcherState::Stop);
|
self.st.set(DispatcherState::Stop);
|
||||||
|
|
||||||
// get io error
|
// get io error
|
||||||
if let Some(err) = self.state.take_error() {
|
if let Some(err) = self.io.take_error() {
|
||||||
PollService::Item(DispatchItem::Disconnect(Some(err)))
|
PollService::Item(DispatchItem::Disconnect(Some(err)))
|
||||||
} else {
|
} else {
|
||||||
PollService::ServiceError
|
PollService::ServiceError
|
||||||
|
@ -453,7 +434,7 @@ where
|
||||||
// pause io read task
|
// pause io read task
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
log::trace!("service is not ready, register dispatch task");
|
log::trace!("service is not ready, register dispatch task");
|
||||||
read.pause(cx);
|
io.pause(cx);
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
// handle service readiness error
|
// handle service readiness error
|
||||||
|
@ -478,7 +459,7 @@ where
|
||||||
|
|
||||||
/// check keepalive timeout
|
/// check keepalive timeout
|
||||||
fn check_keepalive(&self) {
|
fn check_keepalive(&self) {
|
||||||
if self.state.is_keepalive() {
|
if self.io.is_keepalive() {
|
||||||
log::trace!("keepalive timeout");
|
log::trace!("keepalive timeout");
|
||||||
if let Some(err) = self.shared.error.take() {
|
if let Some(err) = self.shared.error.take() {
|
||||||
self.shared.error.set(Some(err));
|
self.shared.error.set(Some(err));
|
||||||
|
@ -494,11 +475,8 @@ where
|
||||||
let updated = now();
|
let updated = now();
|
||||||
if updated != self.ka_updated.get() {
|
if updated != self.ka_updated.get() {
|
||||||
let ka = time::Duration::from(self.ka());
|
let ka = time::Duration::from(self.ka());
|
||||||
self.timer.register(
|
self.timer
|
||||||
updated + ka,
|
.register(updated + ka, self.ka_updated.get() + ka, &self.io);
|
||||||
self.ka_updated.get() + ka,
|
|
||||||
&self.state,
|
|
||||||
);
|
|
||||||
self.ka_updated.set(updated);
|
self.ka_updated.set(updated);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -509,7 +487,7 @@ where
|
||||||
if self.ka_enabled() {
|
if self.ka_enabled() {
|
||||||
self.timer.unregister(
|
self.timer.unregister(
|
||||||
self.ka_updated.get() + time::Duration::from(self.ka()),
|
self.ka_updated.get() + time::Duration::from(self.ka()),
|
||||||
&self.state,
|
&self.io,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -527,7 +505,7 @@ mod tests {
|
||||||
use ntex_util::time::{sleep, Millis};
|
use ntex_util::time::{sleep, Millis};
|
||||||
|
|
||||||
use crate::testing::IoTest;
|
use crate::testing::IoTest;
|
||||||
use crate::{state::Flags, Io, IoRef, IoStream, WriteRef};
|
use crate::{io::Flags, Io, IoRef, IoStream};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -538,8 +516,8 @@ mod tests {
|
||||||
self.0.flags()
|
self.0.flags()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write(&'_ self) -> WriteRef<'_> {
|
fn io(&self) -> &IoRef {
|
||||||
WriteRef(&self.0)
|
&self.0
|
||||||
}
|
}
|
||||||
|
|
||||||
fn close(&self) {
|
fn close(&self) {
|
||||||
|
@ -587,7 +565,7 @@ mod tests {
|
||||||
ready_err: Cell::new(false),
|
ready_err: Cell::new(false),
|
||||||
st: Cell::new(DispatcherState::Processing),
|
st: Cell::new(DispatcherState::Processing),
|
||||||
pool: state.memory_pool().pool(),
|
pool: state.memory_pool().pool(),
|
||||||
state: state.into_boxed(),
|
io: state.into_boxed(),
|
||||||
shared,
|
shared,
|
||||||
timer,
|
timer,
|
||||||
ka_timeout,
|
ka_timeout,
|
||||||
|
@ -634,7 +612,6 @@ mod tests {
|
||||||
|
|
||||||
#[ntex::test]
|
#[ntex::test]
|
||||||
async fn test_sink() {
|
async fn test_sink() {
|
||||||
env_logger::init();
|
|
||||||
let (client, server) = IoTest::create();
|
let (client, server) = IoTest::create();
|
||||||
client.remote_buffer_cap(1024);
|
client.remote_buffer_cap(1024);
|
||||||
client.write("GET /test HTTP/1\r\n\r\n");
|
client.write("GET /test HTTP/1\r\n\r\n");
|
||||||
|
@ -658,7 +635,7 @@ mod tests {
|
||||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||||
|
|
||||||
assert!(st
|
assert!(st
|
||||||
.write()
|
.io()
|
||||||
.encode(Bytes::from_static(b"test"), &mut BytesCodec)
|
.encode(Bytes::from_static(b"test"), &mut BytesCodec)
|
||||||
.is_ok());
|
.is_ok());
|
||||||
let buf = client.read().await.unwrap();
|
let buf = client.read().await.unwrap();
|
||||||
|
@ -684,7 +661,7 @@ mod tests {
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
state
|
state
|
||||||
.write()
|
.io()
|
||||||
.encode(
|
.encode(
|
||||||
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
|
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
|
||||||
&mut BytesCodec,
|
&mut BytesCodec,
|
||||||
|
@ -737,7 +714,7 @@ mod tests {
|
||||||
|
|
||||||
let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
|
let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
|
||||||
state
|
state
|
||||||
.write()
|
.io()
|
||||||
.encode(
|
.encode(
|
||||||
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
|
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
|
||||||
&mut BytesCodec,
|
&mut BytesCodec,
|
||||||
|
@ -821,19 +798,19 @@ mod tests {
|
||||||
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
|
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
|
||||||
|
|
||||||
// response message
|
// response message
|
||||||
assert!(!state.write().is_ready());
|
assert!(!state.io().is_write_ready());
|
||||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 65536);
|
assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 65536);
|
||||||
|
|
||||||
client.remote_buffer_cap(10240);
|
client.remote_buffer_cap(10240);
|
||||||
sleep(Millis(50)).await;
|
sleep(Millis(50)).await;
|
||||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 55296);
|
assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 55296);
|
||||||
|
|
||||||
client.remote_buffer_cap(45056);
|
client.remote_buffer_cap(45056);
|
||||||
sleep(Millis(50)).await;
|
sleep(Millis(50)).await;
|
||||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 10240);
|
assert_eq!(state.io().with_write_buf(|buf| buf.len()).unwrap(), 10240);
|
||||||
|
|
||||||
// backpressure disabled
|
// backpressure disabled
|
||||||
assert!(state.write().is_ready());
|
assert!(state.io().is_write_ready());
|
||||||
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
|
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,8 @@ use std::{any, io, task::Context, task::Poll};
|
||||||
|
|
||||||
use ntex_bytes::BytesMut;
|
use ntex_bytes::BytesMut;
|
||||||
|
|
||||||
use super::state::{Flags, IoRef};
|
use super::io::Flags;
|
||||||
use super::{Filter, WriteReadiness};
|
use super::{Filter, IoRef, WriteReadiness};
|
||||||
|
|
||||||
pub struct DefaultFilter(IoRef);
|
pub struct DefaultFilter(IoRef);
|
||||||
|
|
||||||
|
|
698
ntex-io/src/io.rs
Normal file
698
ntex-io/src/io.rs
Normal file
|
@ -0,0 +1,698 @@
|
||||||
|
use std::cell::{Cell, RefCell};
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use std::{fmt, future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
|
||||||
|
|
||||||
|
use ntex_bytes::{BytesMut, PoolId, PoolRef};
|
||||||
|
use ntex_codec::{Decoder, Encoder};
|
||||||
|
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Millis};
|
||||||
|
|
||||||
|
use super::filter::{DefaultFilter, NullFilter};
|
||||||
|
use super::tasks::{ReadContext, WriteContext};
|
||||||
|
use super::{Filter, FilterFactory, Handle, IoStream};
|
||||||
|
|
||||||
|
bitflags::bitflags! {
|
||||||
|
pub struct Flags: u16 {
|
||||||
|
/// io error occured
|
||||||
|
const IO_ERR = 0b0000_0000_0000_0001;
|
||||||
|
/// shuting down filters
|
||||||
|
const IO_FILTERS = 0b0000_0000_0000_0010;
|
||||||
|
/// shuting down filters timeout
|
||||||
|
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
|
||||||
|
/// shutdown io tasks
|
||||||
|
const IO_SHUTDOWN = 0b0000_0000_0000_1000;
|
||||||
|
/// io object is closed
|
||||||
|
const IO_CLOSED = 0b0000_0000_0001_0000;
|
||||||
|
|
||||||
|
/// pause io read
|
||||||
|
const RD_PAUSED = 0b0000_0000_0010_0000;
|
||||||
|
/// new data is available
|
||||||
|
const RD_READY = 0b0000_0000_0100_0000;
|
||||||
|
/// read buffer is full
|
||||||
|
const RD_BUF_FULL = 0b0000_0000_1000_0000;
|
||||||
|
|
||||||
|
/// wait write completion
|
||||||
|
const WR_WAIT = 0b0000_0001_0000_0000;
|
||||||
|
/// write buffer is full
|
||||||
|
const WR_BACKPRESSURE = 0b0000_0010_0000_0000;
|
||||||
|
|
||||||
|
/// dispatcher is marked stopped
|
||||||
|
const DSP_STOP = 0b0001_0000_0000_0000;
|
||||||
|
/// keep-alive timeout occured
|
||||||
|
const DSP_KEEPALIVE = 0b0010_0000_0000_0000;
|
||||||
|
/// dispatcher returned error
|
||||||
|
const DSP_ERR = 0b0100_0000_0000_0000;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum FilterItem<F> {
|
||||||
|
Boxed(Box<dyn Filter>),
|
||||||
|
Ptr(*mut F),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Io<F = DefaultFilter>(pub(super) IoRef, FilterItem<F>);
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct IoRef(pub(super) Rc<IoState>);
|
||||||
|
|
||||||
|
pub(crate) struct IoState {
|
||||||
|
pub(super) flags: Cell<Flags>,
|
||||||
|
pub(super) pool: Cell<PoolRef>,
|
||||||
|
pub(super) disconnect_timeout: Cell<Millis>,
|
||||||
|
pub(super) error: Cell<Option<io::Error>>,
|
||||||
|
pub(super) read_task: LocalWaker,
|
||||||
|
pub(super) write_task: LocalWaker,
|
||||||
|
pub(super) dispatch_task: LocalWaker,
|
||||||
|
pub(super) read_buf: Cell<Option<BytesMut>>,
|
||||||
|
pub(super) write_buf: Cell<Option<BytesMut>>,
|
||||||
|
pub(super) filter: Cell<&'static dyn Filter>,
|
||||||
|
pub(super) handle: Cell<Option<Box<dyn Handle>>>,
|
||||||
|
pub(super) on_disconnect: RefCell<Vec<Option<LocalWaker>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IoState {
|
||||||
|
#[inline]
|
||||||
|
pub(super) fn insert_flags(&self, f: Flags) {
|
||||||
|
let mut flags = self.flags.get();
|
||||||
|
flags.insert(f);
|
||||||
|
self.flags.set(flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(super) fn remove_flags(&self, f: Flags) {
|
||||||
|
let mut flags = self.flags.get();
|
||||||
|
flags.remove(f);
|
||||||
|
self.flags.set(flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(super) fn notify_keepalive(&self) {
|
||||||
|
let mut flags = self.flags.get();
|
||||||
|
if !flags.contains(Flags::DSP_KEEPALIVE) {
|
||||||
|
flags.insert(Flags::DSP_KEEPALIVE);
|
||||||
|
self.flags.set(flags);
|
||||||
|
self.dispatch_task.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(super) fn notify_disconnect(&self) {
|
||||||
|
let mut on_disconnect = self.on_disconnect.borrow_mut();
|
||||||
|
for item in &mut *on_disconnect {
|
||||||
|
if let Some(waker) = item.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(super) fn is_io_open(&self) -> bool {
|
||||||
|
!self.flags.get().intersects(
|
||||||
|
Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_SHUTDOWN | Flags::IO_CLOSED,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(super) fn set_error(&self, err: Option<io::Error>) {
|
||||||
|
if err.is_some() {
|
||||||
|
self.error.set(err);
|
||||||
|
}
|
||||||
|
self.read_task.wake();
|
||||||
|
self.write_task.wake();
|
||||||
|
self.dispatch_task.wake();
|
||||||
|
self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
|
||||||
|
self.notify_disconnect();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Gracefully shutdown read and write io tasks
|
||||||
|
pub(super) fn init_shutdown(&self, cx: Option<&mut Context<'_>>, st: &IoRef) {
|
||||||
|
let flags = self.flags.get();
|
||||||
|
|
||||||
|
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
|
||||||
|
log::trace!("initiate io shutdown {:?}", flags);
|
||||||
|
self.insert_flags(Flags::IO_FILTERS);
|
||||||
|
if let Err(err) = self.shutdown_filters(st) {
|
||||||
|
self.error.set(Some(err));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.read_task.wake();
|
||||||
|
self.write_task.wake();
|
||||||
|
if let Some(cx) = cx {
|
||||||
|
self.dispatch_task.register(cx.waker());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub(super) fn shutdown_filters(&self, st: &IoRef) -> Result<(), io::Error> {
|
||||||
|
let mut flags = self.flags.get();
|
||||||
|
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||||
|
let result = match self.filter.get().shutdown(st) {
|
||||||
|
Poll::Pending => return Ok(()),
|
||||||
|
Poll::Ready(Ok(())) => {
|
||||||
|
flags.insert(Flags::IO_SHUTDOWN);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Poll::Ready(Err(err)) => {
|
||||||
|
flags.insert(Flags::IO_ERR);
|
||||||
|
self.dispatch_task.wake();
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.flags.set(flags);
|
||||||
|
self.read_task.wake();
|
||||||
|
self.write_task.wake();
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Eq for IoState {}
|
||||||
|
|
||||||
|
impl PartialEq for IoState {
|
||||||
|
#[inline]
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
ptr::eq(self, other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl hash::Hash for IoState {
|
||||||
|
#[inline]
|
||||||
|
fn hash<H: hash::Hasher>(&self, state: &mut H) {
|
||||||
|
(self as *const _ as usize).hash(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for IoState {
|
||||||
|
#[inline]
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(buf) = self.read_buf.take() {
|
||||||
|
self.pool.get().release_read_buf(buf);
|
||||||
|
}
|
||||||
|
if let Some(buf) = self.write_buf.take() {
|
||||||
|
self.pool.get().release_write_buf(buf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Io {
|
||||||
|
#[inline]
|
||||||
|
/// Create `State` instance
|
||||||
|
pub fn new<I: IoStream>(io: I) -> Self {
|
||||||
|
Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Create `State` instance with specific memory pool.
|
||||||
|
pub fn with_memory_pool<I: IoStream>(io: I, pool: PoolRef) -> Self {
|
||||||
|
let inner = Rc::new(IoState {
|
||||||
|
pool: Cell::new(pool),
|
||||||
|
flags: Cell::new(Flags::empty()),
|
||||||
|
error: Cell::new(None),
|
||||||
|
disconnect_timeout: Cell::new(Millis::ONE_SEC),
|
||||||
|
dispatch_task: LocalWaker::new(),
|
||||||
|
read_task: LocalWaker::new(),
|
||||||
|
write_task: LocalWaker::new(),
|
||||||
|
read_buf: Cell::new(None),
|
||||||
|
write_buf: Cell::new(None),
|
||||||
|
filter: Cell::new(NullFilter::get()),
|
||||||
|
handle: Cell::new(None),
|
||||||
|
on_disconnect: RefCell::new(Vec::new()),
|
||||||
|
});
|
||||||
|
|
||||||
|
let filter = Box::new(DefaultFilter::new(IoRef(inner.clone())));
|
||||||
|
let filter_ref: &'static dyn Filter = unsafe {
|
||||||
|
let filter: &dyn Filter = filter.as_ref();
|
||||||
|
std::mem::transmute(filter)
|
||||||
|
};
|
||||||
|
inner.filter.replace(filter_ref);
|
||||||
|
|
||||||
|
let io_ref = IoRef(inner);
|
||||||
|
|
||||||
|
// start io tasks
|
||||||
|
let hnd = io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone()));
|
||||||
|
io_ref.0.handle.set(hnd);
|
||||||
|
|
||||||
|
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Io<F> {
|
||||||
|
#[inline]
|
||||||
|
/// Set memory pool
|
||||||
|
pub fn set_memory_pool(&self, pool: PoolRef) {
|
||||||
|
if let Some(mut buf) = self.0 .0.read_buf.take() {
|
||||||
|
pool.move_in(&mut buf);
|
||||||
|
self.0 .0.read_buf.set(Some(buf));
|
||||||
|
}
|
||||||
|
if let Some(mut buf) = self.0 .0.write_buf.take() {
|
||||||
|
pool.move_in(&mut buf);
|
||||||
|
self.0 .0.write_buf.set(Some(buf));
|
||||||
|
}
|
||||||
|
self.0 .0.pool.set(pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Set io disconnect timeout in secs
|
||||||
|
pub fn set_disconnect_timeout(&self, timeout: Millis) {
|
||||||
|
self.0 .0.disconnect_timeout.set(timeout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Io<F> {
|
||||||
|
#[inline]
|
||||||
|
#[doc(hidden)]
|
||||||
|
/// Get current state flags
|
||||||
|
pub fn flags(&self) -> Flags {
|
||||||
|
self.0 .0.flags.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
#[allow(clippy::should_implement_trait)]
|
||||||
|
/// Get IoRef reference
|
||||||
|
pub fn as_ref(&self) -> &IoRef {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Get instance of IoRef
|
||||||
|
pub fn get_ref(&self) -> IoRef {
|
||||||
|
self.0.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if dispatcher marked stopped
|
||||||
|
pub fn is_dispatcher_stopped(&self) -> bool {
|
||||||
|
self.flags().contains(Flags::DSP_STOP)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Register dispatcher task
|
||||||
|
pub fn register_dispatcher(&self, cx: &mut Context<'_>) {
|
||||||
|
self.0 .0.dispatch_task.register(cx.waker());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Reset keep-alive error
|
||||||
|
pub fn reset_keepalive(&self) {
|
||||||
|
self.0 .0.remove_flags(Flags::DSP_KEEPALIVE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Filter> Io<F> {
|
||||||
|
#[inline]
|
||||||
|
/// Get referece to filter
|
||||||
|
pub fn filter(&self) -> &F {
|
||||||
|
if let FilterItem::Ptr(p) = self.1 {
|
||||||
|
if let Some(r) = unsafe { p.as_ref() } {
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic!()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn into_boxed(mut self) -> crate::IoBoxed
|
||||||
|
where
|
||||||
|
F: 'static,
|
||||||
|
{
|
||||||
|
// get current filter
|
||||||
|
let filter = unsafe {
|
||||||
|
let item = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut()));
|
||||||
|
let filter: Box<dyn Filter> = match item {
|
||||||
|
FilterItem::Boxed(b) => b,
|
||||||
|
FilterItem::Ptr(p) => Box::new(*Box::from_raw(p)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let filter_ref: &'static dyn Filter = {
|
||||||
|
let filter: &dyn Filter = filter.as_ref();
|
||||||
|
std::mem::transmute(filter)
|
||||||
|
};
|
||||||
|
self.0 .0.filter.replace(filter_ref);
|
||||||
|
filter
|
||||||
|
};
|
||||||
|
|
||||||
|
Io(self.0.clone(), FilterItem::Boxed(filter))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn add_filter<T>(self, factory: T) -> T::Future
|
||||||
|
where
|
||||||
|
T: FilterFactory<F>,
|
||||||
|
{
|
||||||
|
factory.create(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn map_filter<T, U, E>(mut self, map: U) -> Result<Io<T>, E>
|
||||||
|
where
|
||||||
|
T: Filter,
|
||||||
|
U: FnOnce(F) -> Result<T, E>,
|
||||||
|
{
|
||||||
|
// replace current filter
|
||||||
|
let filter = unsafe {
|
||||||
|
let item = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut()));
|
||||||
|
let filter = match item {
|
||||||
|
FilterItem::Boxed(_) => panic!(),
|
||||||
|
FilterItem::Ptr(p) => {
|
||||||
|
assert!(!p.is_null());
|
||||||
|
Box::new(map(*Box::from_raw(p))?)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let filter_ref: &'static dyn Filter = {
|
||||||
|
let filter: &dyn Filter = filter.as_ref();
|
||||||
|
std::mem::transmute(filter)
|
||||||
|
};
|
||||||
|
self.0 .0.filter.replace(filter_ref);
|
||||||
|
filter
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Io<F> {
|
||||||
|
#[inline]
|
||||||
|
/// Read incoming io stream and decode codec item.
|
||||||
|
pub async fn next<U>(
|
||||||
|
&self,
|
||||||
|
codec: &U,
|
||||||
|
) -> Option<Result<U::Item, Either<U::Error, io::Error>>>
|
||||||
|
where
|
||||||
|
U: Decoder,
|
||||||
|
{
|
||||||
|
poll_fn(|cx| self.poll_read_next(codec, cx)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Encode item, send to a peer
|
||||||
|
pub async fn send<U>(
|
||||||
|
&self,
|
||||||
|
item: U::Item,
|
||||||
|
codec: &U,
|
||||||
|
) -> Result<(), Either<U::Error, io::Error>>
|
||||||
|
where
|
||||||
|
U: Encoder,
|
||||||
|
{
|
||||||
|
let filter = self.filter();
|
||||||
|
let mut buf = filter
|
||||||
|
.get_write_buf()
|
||||||
|
.unwrap_or_else(|| self.memory_pool().get_write_buf());
|
||||||
|
|
||||||
|
let is_write_sleep = buf.is_empty();
|
||||||
|
codec.encode(item, &mut buf).map_err(Either::Left)?;
|
||||||
|
filter.release_write_buf(buf).map_err(Either::Right)?;
|
||||||
|
if is_write_sleep {
|
||||||
|
self.0 .0.write_task.wake();
|
||||||
|
}
|
||||||
|
|
||||||
|
poll_fn(|cx| self.poll_write_ready(cx, true))
|
||||||
|
.await
|
||||||
|
.map_err(Either::Right)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Wake write task and instruct to write data.
|
||||||
|
///
|
||||||
|
/// This is async version of .poll_write_ready() method.
|
||||||
|
pub async fn write_ready(&self, full: bool) -> Result<(), io::Error> {
|
||||||
|
poll_fn(|cx| self.poll_write_ready(cx, full)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Shut down connection
|
||||||
|
pub async fn shutdown(&self) -> Result<(), io::Error> {
|
||||||
|
poll_fn(|cx| self.poll_shutdown(cx)).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Io<F> {
|
||||||
|
#[inline]
|
||||||
|
/// Wake write task and instruct to write data.
|
||||||
|
///
|
||||||
|
/// If full is true then wake up dispatcher when all data is flushed
|
||||||
|
/// otherwise wake up when size of write buffer is lower than
|
||||||
|
/// buffer max size.
|
||||||
|
pub fn poll_write_ready(
|
||||||
|
&self,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
full: bool,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
// check io error
|
||||||
|
if !self.0 .0.is_io_open() {
|
||||||
|
return Poll::Ready(Err(self.0 .0.error.take().unwrap_or_else(|| {
|
||||||
|
io::Error::new(io::ErrorKind::Other, "disconnected")
|
||||||
|
})));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(buf) = self.0 .0.write_buf.take() {
|
||||||
|
let len = buf.len();
|
||||||
|
if len != 0 {
|
||||||
|
self.0 .0.write_buf.set(Some(buf));
|
||||||
|
|
||||||
|
if full {
|
||||||
|
self.0 .0.insert_flags(Flags::WR_WAIT);
|
||||||
|
self.0 .0.dispatch_task.register(cx.waker());
|
||||||
|
return Poll::Pending;
|
||||||
|
} else if len >= self.0.memory_pool().write_params_high() << 1 {
|
||||||
|
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
|
||||||
|
self.0 .0.dispatch_task.register(cx.waker());
|
||||||
|
return Poll::Pending;
|
||||||
|
} else {
|
||||||
|
self.0 .0.remove_flags(Flags::WR_BACKPRESSURE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Wake read task and instruct to read more data
|
||||||
|
///
|
||||||
|
/// Read task is awake only if back-pressure is enabled
|
||||||
|
/// otherwise it is already awake. Buffer read status gets clean up.
|
||||||
|
pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Option<io::Result<()>>> {
|
||||||
|
if !self.0 .0.is_io_open() {
|
||||||
|
if let Some(err) = self.0 .0.error.take() {
|
||||||
|
Poll::Ready(Some(Err(err)))
|
||||||
|
} else {
|
||||||
|
Poll::Ready(None)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.0 .0.dispatch_task.register(cx.waker());
|
||||||
|
|
||||||
|
let mut flags = self.0 .0.flags.get();
|
||||||
|
let ready = flags.contains(Flags::RD_READY);
|
||||||
|
if flags.contains(Flags::RD_BUF_FULL) {
|
||||||
|
log::trace!("read back-pressure is disabled, wake io task");
|
||||||
|
flags.remove(Flags::RD_READY | Flags::RD_BUF_FULL);
|
||||||
|
self.0 .0.read_task.wake();
|
||||||
|
self.0 .0.flags.set(flags);
|
||||||
|
if ready {
|
||||||
|
Poll::Ready(Some(Ok(())))
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
} else if ready {
|
||||||
|
log::trace!("waking up io read task");
|
||||||
|
flags.remove(Flags::RD_READY);
|
||||||
|
self.0 .0.flags.set(flags);
|
||||||
|
self.0 .0.read_task.wake();
|
||||||
|
Poll::Ready(Some(Ok(())))
|
||||||
|
} else {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
pub fn poll_read_next<U>(
|
||||||
|
&self,
|
||||||
|
codec: &U,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Option<Result<U::Item, Either<U::Error, io::Error>>>>
|
||||||
|
where
|
||||||
|
U: Decoder,
|
||||||
|
{
|
||||||
|
match self.decode(codec) {
|
||||||
|
Ok(Some(el)) => Poll::Ready(Some(Ok(el))),
|
||||||
|
Ok(None) => match self.poll_read_ready(cx) {
|
||||||
|
Poll::Pending | Poll::Ready(Some(Ok(()))) => Poll::Pending,
|
||||||
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(Either::Right(e)))),
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
},
|
||||||
|
Err(err) => Poll::Ready(Some(Err(Either::Left(err)))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Shut down connection
|
||||||
|
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
let flags = self.flags();
|
||||||
|
|
||||||
|
if flags.intersects(Flags::IO_ERR | Flags::IO_CLOSED) {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
if !flags.contains(Flags::IO_FILTERS) {
|
||||||
|
self.0 .0.init_shutdown(Some(cx), self.as_ref());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(err) = self.0 .0.error.take() {
|
||||||
|
Poll::Ready(Err(err))
|
||||||
|
} else {
|
||||||
|
self.0 .0.dispatch_task.register(cx.waker());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Pause read task
|
||||||
|
pub fn pause(&self, cx: &mut Context<'_>) {
|
||||||
|
self.0 .0.insert_flags(Flags::RD_PAUSED);
|
||||||
|
self.0 .0.dispatch_task.register(cx.waker());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Wake read io task if it is paused
|
||||||
|
pub fn resume(&self) -> bool {
|
||||||
|
let flags = self.0 .0.flags.get();
|
||||||
|
if flags.contains(Flags::RD_PAUSED) {
|
||||||
|
self.0 .0.remove_flags(Flags::RD_PAUSED);
|
||||||
|
self.0 .0.read_task.wake();
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Wait until write task flushes data to io stream
|
||||||
|
///
|
||||||
|
/// Write task must be waken up separately.
|
||||||
|
pub fn enable_write_backpressure(&self, cx: &mut Context<'_>) {
|
||||||
|
log::trace!("enable write back-pressure for dispatcher");
|
||||||
|
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
|
||||||
|
self.0 .0.dispatch_task.register(cx.waker());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Drop for Io<F> {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let FilterItem::Ptr(p) = self.1 {
|
||||||
|
if p.is_null() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
log::trace!(
|
||||||
|
"io is dropped, force stopping io streams {:?}",
|
||||||
|
self.0.flags()
|
||||||
|
);
|
||||||
|
|
||||||
|
self.force_close();
|
||||||
|
self.0 .0.filter.set(NullFilter::get());
|
||||||
|
let _ = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut()));
|
||||||
|
unsafe { Box::from_raw(p) };
|
||||||
|
} else {
|
||||||
|
log::trace!(
|
||||||
|
"io is dropped, force stopping io streams {:?}",
|
||||||
|
self.0.flags()
|
||||||
|
);
|
||||||
|
self.force_close();
|
||||||
|
self.0 .0.filter.set(NullFilter::get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for Io {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("Io")
|
||||||
|
.field("open", &!self.is_closed())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Deref for Io<F> {
|
||||||
|
type Target = IoRef;
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OnDisconnect future resolves when socket get disconnected
|
||||||
|
#[must_use = "OnDisconnect do nothing unless polled"]
|
||||||
|
pub struct OnDisconnect {
|
||||||
|
token: usize,
|
||||||
|
inner: Rc<IoState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OnDisconnect {
|
||||||
|
pub(super) fn new(inner: Rc<IoState>) -> Self {
|
||||||
|
Self::new_inner(inner.flags.get().contains(Flags::IO_ERR), inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
|
||||||
|
let token = if disconnected {
|
||||||
|
usize::MAX
|
||||||
|
} else {
|
||||||
|
let mut on_disconnect = inner.on_disconnect.borrow_mut();
|
||||||
|
let token = on_disconnect.len();
|
||||||
|
on_disconnect.push(Some(LocalWaker::default()));
|
||||||
|
drop(on_disconnect);
|
||||||
|
token
|
||||||
|
};
|
||||||
|
Self { token, inner }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if connection is disconnected
|
||||||
|
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
|
||||||
|
if self.token == usize::MAX {
|
||||||
|
Poll::Ready(())
|
||||||
|
} else {
|
||||||
|
let on_disconnect = self.inner.on_disconnect.borrow();
|
||||||
|
if on_disconnect[self.token].is_some() {
|
||||||
|
on_disconnect[self.token]
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.register(cx.waker());
|
||||||
|
Poll::Pending
|
||||||
|
} else {
|
||||||
|
Poll::Ready(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for OnDisconnect {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
if self.token == usize::MAX {
|
||||||
|
OnDisconnect::new_inner(true, self.inner.clone())
|
||||||
|
} else {
|
||||||
|
OnDisconnect::new_inner(false, self.inner.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Future for OnDisconnect {
|
||||||
|
type Output = ();
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
self.poll_ready(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for OnDisconnect {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if self.token != usize::MAX {
|
||||||
|
self.inner.on_disconnect.borrow_mut()[self.token].take();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
558
ntex-io/src/ioref.rs
Normal file
558
ntex-io/src/ioref.rs
Normal file
|
@ -0,0 +1,558 @@
|
||||||
|
use std::{any, fmt, io};
|
||||||
|
|
||||||
|
use ntex_bytes::{BytesMut, PoolRef};
|
||||||
|
use ntex_codec::{Decoder, Encoder};
|
||||||
|
|
||||||
|
use super::io::{Flags, IoRef, OnDisconnect};
|
||||||
|
use super::{types, Filter};
|
||||||
|
|
||||||
|
impl IoRef {
|
||||||
|
#[inline]
|
||||||
|
#[doc(hidden)]
|
||||||
|
/// Get current state flags
|
||||||
|
pub fn flags(&self) -> Flags {
|
||||||
|
self.0.flags.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Set flags
|
||||||
|
pub(crate) fn set_flags(&self, flags: Flags) {
|
||||||
|
self.0.flags.set(flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Get memory pool
|
||||||
|
pub(crate) fn filter(&self) -> &dyn Filter {
|
||||||
|
self.0.filter.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Get memory pool
|
||||||
|
pub fn memory_pool(&self) -> PoolRef {
|
||||||
|
self.0.pool.get()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if io is still active
|
||||||
|
pub fn is_io_open(&self) -> bool {
|
||||||
|
self.0.is_io_open()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if keep-alive timeout occured
|
||||||
|
pub fn is_keepalive(&self) -> bool {
|
||||||
|
self.0.flags.get().contains(Flags::DSP_KEEPALIVE)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if io stream is closed
|
||||||
|
pub fn is_closed(&self) -> bool {
|
||||||
|
self.0.flags.get().intersects(
|
||||||
|
Flags::IO_ERR
|
||||||
|
| Flags::IO_SHUTDOWN
|
||||||
|
| Flags::IO_CLOSED
|
||||||
|
| Flags::IO_FILTERS
|
||||||
|
| Flags::DSP_STOP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Take io error if any occured
|
||||||
|
pub fn take_error(&self) -> Option<io::Error> {
|
||||||
|
self.0.error.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Wake dispatcher task
|
||||||
|
pub fn wake_dispatcher(&self) {
|
||||||
|
self.0.dispatch_task.wake();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Mark dispatcher as stopped
|
||||||
|
pub fn stop_dispatcher(&self) {
|
||||||
|
self.0.insert_flags(Flags::DSP_STOP);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Gracefully close connection
|
||||||
|
///
|
||||||
|
/// First stop dispatcher, then dispatcher stops io tasks
|
||||||
|
pub fn close(&self) {
|
||||||
|
self.0.insert_flags(Flags::DSP_STOP);
|
||||||
|
self.0.dispatch_task.wake();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Force close connection
|
||||||
|
///
|
||||||
|
/// Dispatcher does not wait for uncompleted responses, but flushes io buffers.
|
||||||
|
pub fn force_close(&self) {
|
||||||
|
log::trace!("force close framed object");
|
||||||
|
self.0.insert_flags(Flags::DSP_STOP | Flags::IO_SHUTDOWN);
|
||||||
|
self.0.read_task.wake();
|
||||||
|
self.0.write_task.wake();
|
||||||
|
self.0.dispatch_task.wake();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Notify when io stream get disconnected
|
||||||
|
pub fn on_disconnect(&self) -> OnDisconnect {
|
||||||
|
OnDisconnect::new(self.0.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Query specific data
|
||||||
|
pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
|
||||||
|
if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
|
||||||
|
types::QueryItem::new(item)
|
||||||
|
} else {
|
||||||
|
types::QueryItem::empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if write task is ready
|
||||||
|
pub fn is_write_ready(&self) -> bool {
|
||||||
|
!self.0.flags.get().contains(Flags::WR_BACKPRESSURE)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if read buffer has new data
|
||||||
|
pub fn is_read_ready(&self) -> bool {
|
||||||
|
self.0.flags.get().contains(Flags::RD_READY)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if write buffer is full
|
||||||
|
pub fn is_write_buf_full(&self) -> bool {
|
||||||
|
if let Some(buf) = self.0.read_buf.take() {
|
||||||
|
let hw = self.memory_pool().write_params_high();
|
||||||
|
let result = buf.len() >= hw;
|
||||||
|
self.0.write_buf.set(Some(buf));
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Check if read buffer is full
|
||||||
|
pub fn is_read_buf_full(&self) -> bool {
|
||||||
|
if let Some(buf) = self.0.read_buf.take() {
|
||||||
|
let result = buf.len() >= self.memory_pool().read_params_high();
|
||||||
|
self.0.read_buf.set(Some(buf));
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Wait until write task flushes data to io stream
|
||||||
|
///
|
||||||
|
/// Write task must be waken up separately.
|
||||||
|
pub fn enable_write_backpressure(&self) {
|
||||||
|
log::trace!("enable write back-pressure");
|
||||||
|
self.0.insert_flags(Flags::WR_BACKPRESSURE);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Get mut access to write buffer
|
||||||
|
pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error>
|
||||||
|
where
|
||||||
|
F: FnOnce(&mut BytesMut) -> R,
|
||||||
|
{
|
||||||
|
let filter = self.0.filter.get();
|
||||||
|
let mut buf = filter
|
||||||
|
.get_write_buf()
|
||||||
|
.unwrap_or_else(|| self.memory_pool().get_write_buf());
|
||||||
|
if buf.is_empty() {
|
||||||
|
self.0.write_task.wake();
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = f(&mut buf);
|
||||||
|
filter.release_write_buf(buf)?;
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Get mut access to read buffer
|
||||||
|
pub fn with_read_buf<F, R>(&self, f: F) -> R
|
||||||
|
where
|
||||||
|
F: FnOnce(&mut BytesMut) -> R,
|
||||||
|
{
|
||||||
|
let mut buf = self
|
||||||
|
.0
|
||||||
|
.read_buf
|
||||||
|
.take()
|
||||||
|
.unwrap_or_else(|| self.memory_pool().get_read_buf());
|
||||||
|
let res = f(&mut buf);
|
||||||
|
if buf.is_empty() {
|
||||||
|
self.memory_pool().release_read_buf(buf);
|
||||||
|
} else {
|
||||||
|
self.0.read_buf.set(Some(buf));
|
||||||
|
}
|
||||||
|
res
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Encode and write item to a buffer and wake up write task
|
||||||
|
///
|
||||||
|
/// Returns write buffer state, false is returned if write buffer if full.
|
||||||
|
pub fn encode<U>(
|
||||||
|
&self,
|
||||||
|
item: U::Item,
|
||||||
|
codec: &U,
|
||||||
|
) -> Result<bool, <U as Encoder>::Error>
|
||||||
|
where
|
||||||
|
U: Encoder,
|
||||||
|
{
|
||||||
|
let flags = self.0.flags.get();
|
||||||
|
|
||||||
|
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||||
|
let filter = self.0.filter.get();
|
||||||
|
let mut buf = filter
|
||||||
|
.get_write_buf()
|
||||||
|
.unwrap_or_else(|| self.memory_pool().get_write_buf());
|
||||||
|
let is_write_sleep = buf.is_empty();
|
||||||
|
let (hw, lw) = self.memory_pool().write_params().unpack();
|
||||||
|
|
||||||
|
// make sure we've got room
|
||||||
|
let remaining = buf.capacity() - buf.len();
|
||||||
|
if remaining < lw {
|
||||||
|
buf.reserve(hw - remaining);
|
||||||
|
}
|
||||||
|
|
||||||
|
// encode item and wake write task
|
||||||
|
let result = codec.encode(item, &mut buf).map(|_| {
|
||||||
|
if is_write_sleep {
|
||||||
|
self.0.write_task.wake();
|
||||||
|
}
|
||||||
|
buf.len() < hw
|
||||||
|
});
|
||||||
|
if let Err(err) = filter.release_write_buf(buf) {
|
||||||
|
self.0.set_error(Some(err));
|
||||||
|
}
|
||||||
|
result
|
||||||
|
} else {
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Attempts to decode a frame from the read buffer
|
||||||
|
///
|
||||||
|
/// Read buffer ready state gets cleanup if decoder cannot
|
||||||
|
/// decode any frame.
|
||||||
|
pub fn decode<U>(
|
||||||
|
&self,
|
||||||
|
codec: &U,
|
||||||
|
) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
|
||||||
|
where
|
||||||
|
U: Decoder,
|
||||||
|
{
|
||||||
|
if let Some(mut buf) = self.0.read_buf.take() {
|
||||||
|
let result = codec.decode(&mut buf);
|
||||||
|
self.0.read_buf.set(Some(buf));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
/// Write bytes to a buffer and wake up write task
|
||||||
|
///
|
||||||
|
/// Returns write buffer state, false is returned if write buffer if full.
|
||||||
|
pub fn write(&self, src: &[u8]) -> Result<bool, io::Error> {
|
||||||
|
let flags = self.0.flags.get();
|
||||||
|
|
||||||
|
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||||
|
let filter = self.0.filter.get();
|
||||||
|
let mut buf = filter
|
||||||
|
.get_write_buf()
|
||||||
|
.unwrap_or_else(|| self.memory_pool().get_write_buf());
|
||||||
|
let is_write_sleep = buf.is_empty();
|
||||||
|
|
||||||
|
// write and wake write task
|
||||||
|
buf.extend_from_slice(src);
|
||||||
|
let result = buf.len() < self.memory_pool().write_params_high();
|
||||||
|
if is_write_sleep {
|
||||||
|
self.0.write_task.wake();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = filter.release_write_buf(buf) {
|
||||||
|
self.0.set_error(Some(err));
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
} else {
|
||||||
|
Ok(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for IoRef {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("IoRef")
|
||||||
|
.field("open", &!self.is_closed())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::{cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
|
||||||
|
|
||||||
|
use ntex_bytes::Bytes;
|
||||||
|
use ntex_codec::BytesCodec;
|
||||||
|
use ntex_util::future::{lazy, poll_fn, Ready};
|
||||||
|
use ntex_util::time::{sleep, Millis};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::testing::IoTest;
|
||||||
|
use crate::{Filter, FilterFactory, Io, WriteReadiness};
|
||||||
|
|
||||||
|
const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
|
||||||
|
const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
|
||||||
|
|
||||||
|
#[ntex::test]
|
||||||
|
async fn utils() {
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
client.remote_buffer_cap(1024);
|
||||||
|
client.write(TEXT);
|
||||||
|
|
||||||
|
let state = Io::new(server);
|
||||||
|
assert!(!state.is_read_buf_full());
|
||||||
|
assert!(!state.is_write_buf_full());
|
||||||
|
|
||||||
|
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
|
||||||
|
assert_eq!(msg, Bytes::from_static(BIN));
|
||||||
|
|
||||||
|
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
|
||||||
|
assert!(res.is_pending());
|
||||||
|
client.write(TEXT);
|
||||||
|
sleep(Millis(50)).await;
|
||||||
|
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
|
||||||
|
if let Poll::Ready(msg) = res {
|
||||||
|
assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN));
|
||||||
|
}
|
||||||
|
|
||||||
|
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
|
||||||
|
let msg = state.next(&BytesCodec).await;
|
||||||
|
assert!(msg.unwrap().is_err());
|
||||||
|
assert!(state.flags().contains(Flags::IO_ERR));
|
||||||
|
assert!(state.flags().contains(Flags::DSP_STOP));
|
||||||
|
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
client.remote_buffer_cap(1024);
|
||||||
|
let state = Io::new(server);
|
||||||
|
|
||||||
|
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
|
||||||
|
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
|
||||||
|
if let Poll::Ready(msg) = res {
|
||||||
|
assert!(msg.unwrap().is_err());
|
||||||
|
assert!(state.flags().contains(Flags::IO_ERR));
|
||||||
|
assert!(state.flags().contains(Flags::DSP_STOP));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
client.remote_buffer_cap(1024);
|
||||||
|
let state = Io::new(server);
|
||||||
|
state
|
||||||
|
.send(Bytes::from_static(b"test"), &BytesCodec)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let buf = client.read().await.unwrap();
|
||||||
|
assert_eq!(buf, Bytes::from_static(b"test"));
|
||||||
|
|
||||||
|
client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
|
||||||
|
let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
|
||||||
|
assert!(res.is_err());
|
||||||
|
assert!(state.flags().contains(Flags::IO_ERR));
|
||||||
|
assert!(state.flags().contains(Flags::DSP_STOP));
|
||||||
|
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
client.remote_buffer_cap(1024);
|
||||||
|
let state = Io::new(server);
|
||||||
|
state.force_close();
|
||||||
|
assert!(state.flags().contains(Flags::DSP_STOP));
|
||||||
|
assert!(state.flags().contains(Flags::IO_SHUTDOWN));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[ntex::test]
|
||||||
|
async fn on_disconnect() {
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
let state = Io::new(server);
|
||||||
|
let mut waiter = state.on_disconnect();
|
||||||
|
assert_eq!(
|
||||||
|
lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
|
||||||
|
Poll::Pending
|
||||||
|
);
|
||||||
|
let mut waiter2 = waiter.clone();
|
||||||
|
assert_eq!(
|
||||||
|
lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await,
|
||||||
|
Poll::Pending
|
||||||
|
);
|
||||||
|
client.close().await;
|
||||||
|
assert_eq!(waiter.await, ());
|
||||||
|
assert_eq!(waiter2.await, ());
|
||||||
|
|
||||||
|
let mut waiter = state.on_disconnect();
|
||||||
|
assert_eq!(
|
||||||
|
lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
|
||||||
|
Poll::Ready(())
|
||||||
|
);
|
||||||
|
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
let state = Io::new(server);
|
||||||
|
let mut waiter = state.on_disconnect();
|
||||||
|
assert_eq!(
|
||||||
|
lazy(|cx| Pin::new(&mut waiter).poll(cx)).await,
|
||||||
|
Poll::Pending
|
||||||
|
);
|
||||||
|
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
|
||||||
|
assert_eq!(waiter.await, ());
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Counter<F> {
|
||||||
|
inner: F,
|
||||||
|
in_bytes: Rc<Cell<usize>>,
|
||||||
|
out_bytes: Rc<Cell<usize>>,
|
||||||
|
}
|
||||||
|
impl<F: Filter> Filter for Counter<F> {
|
||||||
|
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn query(&self, _: std::any::TypeId) -> Option<Box<dyn std::any::Any>> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
|
||||||
|
self.inner.poll_read_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn closed(&self, err: Option<io::Error>) {
|
||||||
|
self.inner.closed(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||||
|
self.inner.get_read_buf()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn release_read_buf(
|
||||||
|
&self,
|
||||||
|
buf: BytesMut,
|
||||||
|
new_bytes: usize,
|
||||||
|
) -> Result<(), io::Error> {
|
||||||
|
self.in_bytes.set(self.in_bytes.get() + new_bytes);
|
||||||
|
self.inner.release_read_buf(buf, new_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_ready(
|
||||||
|
&self,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), WriteReadiness>> {
|
||||||
|
self.inner.poll_write_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_write_buf(&self) -> Option<BytesMut> {
|
||||||
|
if let Some(buf) = self.inner.get_write_buf() {
|
||||||
|
self.out_bytes.set(self.out_bytes.get() - buf.len());
|
||||||
|
Some(buf)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
|
||||||
|
self.out_bytes.set(self.out_bytes.get() + buf.len());
|
||||||
|
self.inner.release_write_buf(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CounterFactory(Rc<Cell<usize>>, Rc<Cell<usize>>);
|
||||||
|
|
||||||
|
impl<F: Filter> FilterFactory<F> for CounterFactory {
|
||||||
|
type Filter = Counter<F>;
|
||||||
|
|
||||||
|
type Error = ();
|
||||||
|
type Future = Ready<Io<Counter<F>>, Self::Error>;
|
||||||
|
|
||||||
|
fn create(self, io: Io<F>) -> Self::Future {
|
||||||
|
let in_bytes = self.0.clone();
|
||||||
|
let out_bytes = self.1.clone();
|
||||||
|
Ready::Ok(
|
||||||
|
io.map_filter(|inner| {
|
||||||
|
Ok::<_, ()>(Counter {
|
||||||
|
inner,
|
||||||
|
in_bytes,
|
||||||
|
out_bytes,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[ntex::test]
|
||||||
|
async fn filter() {
|
||||||
|
let in_bytes = Rc::new(Cell::new(0));
|
||||||
|
let out_bytes = Rc::new(Cell::new(0));
|
||||||
|
let factory = CounterFactory(in_bytes.clone(), out_bytes.clone());
|
||||||
|
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
let state = Io::new(server).add_filter(factory).await.unwrap();
|
||||||
|
|
||||||
|
client.remote_buffer_cap(1024);
|
||||||
|
client.write(TEXT);
|
||||||
|
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
|
||||||
|
assert_eq!(msg, Bytes::from_static(BIN));
|
||||||
|
|
||||||
|
state
|
||||||
|
.send(Bytes::from_static(b"test"), &BytesCodec)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let buf = client.read().await.unwrap();
|
||||||
|
assert_eq!(buf, Bytes::from_static(b"test"));
|
||||||
|
|
||||||
|
assert_eq!(in_bytes.get(), BIN.len());
|
||||||
|
assert_eq!(out_bytes.get(), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[ntex::test]
|
||||||
|
async fn boxed_filter() {
|
||||||
|
let in_bytes = Rc::new(Cell::new(0));
|
||||||
|
let out_bytes = Rc::new(Cell::new(0));
|
||||||
|
|
||||||
|
let (client, server) = IoTest::create();
|
||||||
|
let state = Io::new(server)
|
||||||
|
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let state = state.into_boxed();
|
||||||
|
|
||||||
|
client.remote_buffer_cap(1024);
|
||||||
|
client.write(TEXT);
|
||||||
|
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
|
||||||
|
assert_eq!(msg, Bytes::from_static(BIN));
|
||||||
|
|
||||||
|
state
|
||||||
|
.send(Bytes::from_static(b"test"), &BytesCodec)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let buf = client.read().await.unwrap();
|
||||||
|
assert_eq!(buf, Bytes::from_static(b"test"));
|
||||||
|
|
||||||
|
assert_eq!(in_bytes.get(), BIN.len() * 2);
|
||||||
|
assert_eq!(out_bytes.get(), 8);
|
||||||
|
|
||||||
|
// refs
|
||||||
|
assert_eq!(Rc::strong_count(&in_bytes), 3);
|
||||||
|
drop(state);
|
||||||
|
assert_eq!(Rc::strong_count(&in_bytes), 1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,11 +1,15 @@
|
||||||
use std::{any::Any, any::TypeId, fmt, future::Future, io, task::Context, task::Poll};
|
use std::{
|
||||||
|
any::Any, any::TypeId, fmt, future::Future, io::Error as IoError, task::Context,
|
||||||
|
task::Poll,
|
||||||
|
};
|
||||||
|
|
||||||
pub mod testing;
|
pub mod testing;
|
||||||
pub mod types;
|
pub mod types;
|
||||||
|
|
||||||
mod dispatcher;
|
mod dispatcher;
|
||||||
mod filter;
|
mod filter;
|
||||||
mod state;
|
mod io;
|
||||||
|
mod ioref;
|
||||||
mod tasks;
|
mod tasks;
|
||||||
mod time;
|
mod time;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
@ -21,7 +25,7 @@ use ntex_util::time::Millis;
|
||||||
|
|
||||||
pub use self::dispatcher::Dispatcher;
|
pub use self::dispatcher::Dispatcher;
|
||||||
pub use self::filter::DefaultFilter;
|
pub use self::filter::DefaultFilter;
|
||||||
pub use self::state::{Io, IoRef, OnDisconnect, ReadRef, WriteRef};
|
pub use self::io::{Io, IoRef, OnDisconnect};
|
||||||
pub use self::tasks::{ReadContext, WriteContext};
|
pub use self::tasks::{ReadContext, WriteContext};
|
||||||
pub use self::time::Timer;
|
pub use self::time::Timer;
|
||||||
|
|
||||||
|
@ -37,9 +41,9 @@ pub enum WriteReadiness {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Filter: 'static {
|
pub trait Filter: 'static {
|
||||||
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
|
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), IoError>>;
|
||||||
|
|
||||||
fn closed(&self, err: Option<io::Error>);
|
fn closed(&self, err: Option<IoError>);
|
||||||
|
|
||||||
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
|
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
|
||||||
|
|
||||||
|
@ -52,9 +56,9 @@ pub trait Filter: 'static {
|
||||||
|
|
||||||
fn get_write_buf(&self) -> Option<BytesMut>;
|
fn get_write_buf(&self) -> Option<BytesMut>;
|
||||||
|
|
||||||
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error>;
|
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), IoError>;
|
||||||
|
|
||||||
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
|
fn release_write_buf(&self, buf: BytesMut) -> Result<(), IoError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait FilterFactory<F: Filter>: Sized {
|
pub trait FilterFactory<F: Filter>: Sized {
|
||||||
|
@ -88,7 +92,7 @@ pub enum DispatchItem<U: Encoder + Decoder> {
|
||||||
/// Encoder parse error
|
/// Encoder parse error
|
||||||
EncoderError(<U as Encoder>::Error),
|
EncoderError(<U as Encoder>::Error),
|
||||||
/// Socket is disconnected
|
/// Socket is disconnected
|
||||||
Disconnect(Option<io::Error>),
|
Disconnect(Option<IoError>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<U> fmt::Debug for DispatchItem<U>
|
impl<U> fmt::Debug for DispatchItem<U>
|
||||||
|
@ -134,6 +138,7 @@ pub mod rt {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use ntex_codec::BytesCodec;
|
use ntex_codec::BytesCodec;
|
||||||
|
use std::io;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_fmt() {
|
fn test_fmt() {
|
||||||
|
|
1255
ntex-io/src/state.rs
1255
ntex-io/src/state.rs
File diff suppressed because it is too large
Load diff
|
@ -2,7 +2,7 @@ use std::{io, task::Context, task::Poll};
|
||||||
|
|
||||||
use ntex_bytes::{BytesMut, PoolRef};
|
use ntex_bytes::{BytesMut, PoolRef};
|
||||||
|
|
||||||
use super::{state::Flags, IoRef, WriteReadiness};
|
use super::{io::Flags, IoRef, WriteReadiness};
|
||||||
|
|
||||||
pub struct ReadContext(pub(super) IoRef);
|
pub struct ReadContext(pub(super) IoRef);
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ impl ReadContext {
|
||||||
flags.insert(Flags::RD_READY);
|
flags.insert(Flags::RD_READY);
|
||||||
self.0.set_flags(flags);
|
self.0.set_flags(flags);
|
||||||
self.0 .0.dispatch_task.wake();
|
self.0 .0.dispatch_task.wake();
|
||||||
|
log::trace!("new {} bytes available, wakeup dispatcher", new_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.0.filter().release_read_buf(buf, new_bytes)?;
|
self.0.filter().release_read_buf(buf, new_bytes)?;
|
||||||
|
|
|
@ -409,7 +409,10 @@ impl Future for ReadTask {
|
||||||
|
|
||||||
match io.poll_read_buf(cx, &mut buf) {
|
match io.poll_read_buf(cx, &mut buf) {
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
log::trace!("no more data in io stream");
|
log::trace!(
|
||||||
|
"no more data in io stream, read: {:?}",
|
||||||
|
new_bytes
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Poll::Ready(Ok(n)) => {
|
Poll::Ready(Ok(n)) => {
|
||||||
|
|
|
@ -4,15 +4,14 @@ use std::{
|
||||||
|
|
||||||
use ntex_util::time::{now, sleep, Millis};
|
use ntex_util::time::{now, sleep, Millis};
|
||||||
|
|
||||||
use crate::rt::spawn;
|
use crate::{io::IoState, rt::spawn, IoRef};
|
||||||
use crate::state::{IoRef, IoStateInner};
|
|
||||||
|
|
||||||
pub struct Timer(Rc<RefCell<Inner>>);
|
pub struct Timer(Rc<RefCell<Inner>>);
|
||||||
|
|
||||||
struct Inner {
|
struct Inner {
|
||||||
running: bool,
|
running: bool,
|
||||||
resolution: Millis,
|
resolution: Millis,
|
||||||
notifications: BTreeMap<Instant, HashSet<Rc<IoStateInner>, fxhash::FxBuildHasher>>,
|
notifications: BTreeMap<Instant, HashSet<Rc<IoState>, fxhash::FxBuildHasher>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Inner {
|
impl Inner {
|
||||||
|
|
|
@ -355,21 +355,17 @@ impl<F: Filter> AsyncRead for Io<F> {
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
buf: &mut ReadBuf<'_>,
|
buf: &mut ReadBuf<'_>,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
let read = self.read();
|
let len = self.with_read_buf(|src| {
|
||||||
let len = read.with_buf(|src| {
|
|
||||||
let len = cmp::min(src.len(), buf.capacity());
|
let len = cmp::min(src.len(), buf.capacity());
|
||||||
buf.put_slice(&src.split_to(len));
|
buf.put_slice(&src.split_to(len));
|
||||||
len
|
len
|
||||||
});
|
});
|
||||||
|
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
match read.poll_read_ready(cx) {
|
match self.poll_read_ready(cx) {
|
||||||
Ok(()) => Poll::Pending,
|
Poll::Pending | Poll::Ready(Some(Ok(()))) => Poll::Pending,
|
||||||
Err(Some(e)) => Poll::Ready(Err(e)),
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(e)),
|
||||||
Err(None) => Poll::Ready(Err(io::Error::new(
|
Poll::Ready(None) => Poll::Ready(Ok(())),
|
||||||
io::ErrorKind::Other,
|
|
||||||
"disconnected",
|
|
||||||
))),
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
|
@ -383,18 +379,18 @@ impl<F: Filter> AsyncWrite for Io<F> {
|
||||||
_: &mut Context<'_>,
|
_: &mut Context<'_>,
|
||||||
buf: &[u8],
|
buf: &[u8],
|
||||||
) -> Poll<io::Result<usize>> {
|
) -> Poll<io::Result<usize>> {
|
||||||
Poll::Ready(self.write().write(buf).map(|_| buf.len()))
|
Poll::Ready(self.write(buf).map(|_| buf.len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
self.write().poll_write_ready(cx, false)
|
self.poll_write_ready(cx, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(
|
fn poll_shutdown(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
self.0.poll_shutdown(cx)
|
Io::poll_shutdown(&*self, cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -404,21 +400,17 @@ impl AsyncRead for IoBoxed {
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
buf: &mut ReadBuf<'_>,
|
buf: &mut ReadBuf<'_>,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
let read = self.read();
|
let len = self.with_read_buf(|src| {
|
||||||
let len = read.with_buf(|src| {
|
|
||||||
let len = cmp::min(src.len(), buf.capacity());
|
let len = cmp::min(src.len(), buf.capacity());
|
||||||
buf.put_slice(&src.split_to(len));
|
buf.put_slice(&src.split_to(len));
|
||||||
len
|
len
|
||||||
});
|
});
|
||||||
|
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
match read.poll_read_ready(cx) {
|
match self.poll_read_ready(cx) {
|
||||||
Ok(()) => Poll::Pending,
|
Poll::Pending | Poll::Ready(Some(Ok(()))) => Poll::Pending,
|
||||||
Err(Some(e)) => Poll::Ready(Err(e)),
|
Poll::Ready(Some(Err(e))) => Poll::Ready(Err(e)),
|
||||||
Err(None) => Poll::Ready(Err(io::Error::new(
|
Poll::Ready(None) => Poll::Ready(Ok(())),
|
||||||
io::ErrorKind::Other,
|
|
||||||
"disconnected",
|
|
||||||
))),
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
|
@ -432,18 +424,18 @@ impl AsyncWrite for IoBoxed {
|
||||||
_: &mut Context<'_>,
|
_: &mut Context<'_>,
|
||||||
buf: &[u8],
|
buf: &[u8],
|
||||||
) -> Poll<io::Result<usize>> {
|
) -> Poll<io::Result<usize>> {
|
||||||
Poll::Ready(self.write().write(buf).map(|_| buf.len()))
|
Poll::Ready(self.write(buf).map(|_| buf.len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
self.write().poll_write_ready(cx, false)
|
self.poll_write_ready(cx, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(
|
fn poll_shutdown(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<io::Result<()>> {
|
) -> Poll<io::Result<()>> {
|
||||||
self.0.poll_shutdown(cx)
|
Self::poll_shutdown(&*self, cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "ntex-tls"
|
name = "ntex-tls"
|
||||||
version = "0.1.0-b.0"
|
version = "0.1.0-b.1"
|
||||||
authors = ["ntex contributors <team@ntex.rs>"]
|
authors = ["ntex contributors <team@ntex.rs>"]
|
||||||
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
||||||
keywords = ["network", "framework", "async", "futures"]
|
keywords = ["network", "framework", "async", "futures"]
|
||||||
|
|
|
@ -27,7 +27,7 @@ async fn main() -> io::Result<()> {
|
||||||
let resp = io
|
let resp = io
|
||||||
.next(&codec::BytesCodec)
|
.next(&codec::BytesCodec)
|
||||||
.await
|
.await
|
||||||
.map_err(Either::into_inner)?;
|
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "disconnected"))?;
|
||||||
|
|
||||||
println!("disconnecting");
|
println!("disconnecting");
|
||||||
io.shutdown().await
|
io.shutdown().await
|
||||||
|
|
|
@ -49,17 +49,17 @@ async fn main() -> io::Result<()> {
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match io.next(&codec::BytesCodec).await {
|
match io.next(&codec::BytesCodec).await {
|
||||||
Ok(Some(msg)) => {
|
Some(Ok(msg)) => {
|
||||||
println!("Got message: {:?}", msg);
|
println!("Got message: {:?}", msg);
|
||||||
io.send(msg.freeze(), &codec::BytesCodec)
|
io.send(msg.freeze(), &codec::BytesCodec)
|
||||||
.await
|
.await
|
||||||
.map_err(Either::into_inner)?;
|
.map_err(Either::into_inner)?;
|
||||||
}
|
}
|
||||||
Ok(None) => break,
|
Some(Err(e)) => {
|
||||||
Err(e) => {
|
|
||||||
println!("Got error: {:?}", e);
|
println!("Got error: {:?}", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
None => break,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("Client is disconnected");
|
println!("Client is disconnected");
|
||||||
|
|
|
@ -15,10 +15,10 @@ async fn main() -> io::Result<()> {
|
||||||
// load ssl keys
|
// load ssl keys
|
||||||
let mut builder = ssl::SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
|
let mut builder = ssl::SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();
|
||||||
builder
|
builder
|
||||||
.set_private_key_file("../tests/key.pem", SslFiletype::PEM)
|
.set_private_key_file("./examples/key.pem", SslFiletype::PEM)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
builder
|
builder
|
||||||
.set_certificate_chain_file("../tests/cert.pem")
|
.set_certificate_chain_file("./examples/cert.pem")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let acceptor = builder.build();
|
let acceptor = builder.build();
|
||||||
|
|
||||||
|
@ -30,17 +30,17 @@ async fn main() -> io::Result<()> {
|
||||||
println!("New client is connected");
|
println!("New client is connected");
|
||||||
loop {
|
loop {
|
||||||
match io.next(&codec::BytesCodec).await {
|
match io.next(&codec::BytesCodec).await {
|
||||||
Ok(Some(msg)) => {
|
Some(Ok(msg)) => {
|
||||||
println!("Got message: {:?}", msg);
|
println!("Got message: {:?}", msg);
|
||||||
io.send(msg.freeze(), &codec::BytesCodec)
|
io.send(msg.freeze(), &codec::BytesCodec)
|
||||||
.await
|
.await
|
||||||
.map_err(Either::into_inner)?;
|
.map_err(Either::into_inner)?;
|
||||||
}
|
}
|
||||||
Ok(None) => break,
|
Some(Err(e)) => {
|
||||||
Err(e) => {
|
|
||||||
println!("Got error: {:?}", e);
|
println!("Got error: {:?}", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
None => break,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("Client is disconnected");
|
println!("Client is disconnected");
|
||||||
|
|
|
@ -7,7 +7,7 @@ use std::{
|
||||||
|
|
||||||
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
||||||
use ntex_io::{Filter, FilterFactory, Io, IoRef, WriteReadiness};
|
use ntex_io::{Filter, FilterFactory, Io, IoRef, WriteReadiness};
|
||||||
use ntex_util::{future::poll_fn, time, time::Millis};
|
use ntex_util::{future::poll_fn, ready, time, time::Millis};
|
||||||
use tls_openssl::ssl::{self, SslStream};
|
use tls_openssl::ssl::{self, SslStream};
|
||||||
|
|
||||||
mod accept;
|
mod accept;
|
||||||
|
@ -331,26 +331,26 @@ impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_result<T>(
|
fn handle_result<T, F>(
|
||||||
result: Result<T, ssl::Error>,
|
result: Result<T, ssl::Error>,
|
||||||
st: &IoRef,
|
io: &Io<F>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Result<T, Box<dyn Error>>> {
|
) -> Poll<Result<T, Box<dyn Error>>> {
|
||||||
match result {
|
match result {
|
||||||
Ok(v) => Poll::Ready(Ok(v)),
|
Ok(v) => Poll::Ready(Ok(v)),
|
||||||
Err(e) => match e.code() {
|
Err(e) => match e.code() {
|
||||||
ssl::ErrorCode::WANT_READ => {
|
ssl::ErrorCode::WANT_READ => {
|
||||||
if let Err(e) = st.read().poll_read_ready(cx) {
|
match ready!(io.poll_read_ready(cx)) {
|
||||||
let e = e.unwrap_or_else(|| {
|
None => Err::<_, Box<dyn Error>>(
|
||||||
io::Error::new(io::ErrorKind::Other, "disconnected")
|
io::Error::new(io::ErrorKind::Other, "disconnected").into(),
|
||||||
});
|
),
|
||||||
Poll::Ready(Err(e.into()))
|
Some(Err(err)) => Err(err.into()),
|
||||||
} else {
|
_ => Ok(()),
|
||||||
Poll::Pending
|
}?;
|
||||||
}
|
Poll::Pending
|
||||||
}
|
}
|
||||||
ssl::ErrorCode::WANT_WRITE => {
|
ssl::ErrorCode::WANT_WRITE => {
|
||||||
let _ = st.write().poll_write_ready(cx, true)?;
|
let _ = io.poll_write_ready(cx, true)?;
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
_ => Poll::Ready(Err(Box::new(e))),
|
_ => Poll::Ready(Err(Box::new(e))),
|
||||||
|
|
|
@ -5,7 +5,7 @@ use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
|
||||||
|
|
||||||
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
||||||
use ntex_io::{Filter, Io, IoRef, WriteReadiness};
|
use ntex_io::{Filter, Io, IoRef, WriteReadiness};
|
||||||
use ntex_util::future::poll_fn;
|
use ntex_util::{future::poll_fn, ready};
|
||||||
use tls_rust::{ClientConfig, ClientConnection, ServerName};
|
use tls_rust::{ClientConfig, ClientConnection, ServerName};
|
||||||
|
|
||||||
use super::TlsFilter;
|
use super::TlsFilter;
|
||||||
|
@ -207,16 +207,16 @@ impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
||||||
|
|
||||||
impl<F: Filter> TlsClientFilter<F> {
|
impl<F: Filter> TlsClientFilter<F> {
|
||||||
pub(crate) async fn create(
|
pub(crate) async fn create(
|
||||||
st: Io<F>,
|
io: Io<F>,
|
||||||
cfg: Arc<ClientConfig>,
|
cfg: Arc<ClientConfig>,
|
||||||
domain: ServerName,
|
domain: ServerName,
|
||||||
) -> Result<Io<TlsFilter<F>>, io::Error> {
|
) -> Result<Io<TlsFilter<F>>, io::Error> {
|
||||||
let pool = st.memory_pool();
|
let pool = io.memory_pool();
|
||||||
let session = match ClientConnection::new(cfg, domain) {
|
let session = match ClientConnection::new(cfg, domain) {
|
||||||
Ok(session) => session,
|
Ok(session) => session,
|
||||||
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
||||||
};
|
};
|
||||||
let st = st.map_filter(|inner: F| {
|
let io = io.map_filter(|inner: F| {
|
||||||
let inner = IoInner {
|
let inner = IoInner {
|
||||||
pool,
|
pool,
|
||||||
inner,
|
inner,
|
||||||
|
@ -230,59 +230,49 @@ impl<F: Filter> TlsClientFilter<F> {
|
||||||
}))
|
}))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let filter = st.filter();
|
let filter = io.filter();
|
||||||
let read = st.read();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (result, wants_read) = {
|
let (result, wants_read) = {
|
||||||
let mut session = filter.client().session.borrow_mut();
|
let mut session = filter.client().session.borrow_mut();
|
||||||
let mut inner = filter.client().inner.borrow_mut();
|
let mut inner = filter.client().inner.borrow_mut();
|
||||||
let mut io = Wrapper(&mut *inner);
|
let mut wrp = Wrapper(&mut *inner);
|
||||||
let result = session.complete_io(&mut io);
|
let result = session.complete_io(&mut wrp);
|
||||||
let wants_read = session.wants_read();
|
let wants_read = session.wants_read();
|
||||||
|
|
||||||
if session.wants_write() {
|
if session.wants_write() {
|
||||||
loop {
|
loop {
|
||||||
let n = session.write_tls(&mut io)?;
|
let n = session.write_tls(&mut wrp)?;
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if result.is_ok() && wants_read {
|
|
||||||
poll_fn(|cx| {
|
|
||||||
read.poll_read_ready(cx).map_err(|e| {
|
|
||||||
e.unwrap_or_else(|| {
|
|
||||||
io::Error::new(io::ErrorKind::Other, "disconnected")
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
Poll::Ready(Ok::<_, io::Error>(()))
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
(result, wants_read)
|
(result, wants_read)
|
||||||
};
|
};
|
||||||
|
if result.is_ok() && wants_read {
|
||||||
|
poll_fn(|cx| match ready!(io.poll_read_ready(cx)) {
|
||||||
|
None => Poll::Ready(Err(io::Error::new(
|
||||||
|
io::ErrorKind::Other,
|
||||||
|
"disconnected",
|
||||||
|
))),
|
||||||
|
Some(Err(e)) => Poll::Ready(Err(e)),
|
||||||
|
_ => Poll::Ready(Ok(())),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
match result {
|
match result {
|
||||||
Ok(_) => return Ok(st),
|
Ok(_) => return Ok(io),
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||||
if wants_read {
|
|
||||||
read.take_readiness();
|
|
||||||
}
|
|
||||||
poll_fn(|cx| {
|
poll_fn(|cx| {
|
||||||
let read_ready = if wants_read {
|
let read_ready = if wants_read {
|
||||||
if read.is_ready() {
|
match ready!(io.poll_read_ready(cx)) {
|
||||||
true
|
None => Err(io::Error::new(
|
||||||
} else {
|
io::ErrorKind::Other,
|
||||||
read.poll_read_ready(cx).map_err(|e| {
|
"disconnected",
|
||||||
e.unwrap_or_else(|| {
|
)),
|
||||||
io::Error::new(
|
Some(Err(e)) => Err(e),
|
||||||
io::ErrorKind::Other,
|
Some(Ok(_)) => Ok(true),
|
||||||
"disconnected",
|
}?
|
||||||
)
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
true
|
true
|
||||||
};
|
};
|
||||||
|
|
|
@ -5,7 +5,7 @@ use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
|
||||||
|
|
||||||
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
||||||
use ntex_io::{Filter, Io, IoRef, WriteReadiness};
|
use ntex_io::{Filter, Io, IoRef, WriteReadiness};
|
||||||
use ntex_util::{future::poll_fn, time, time::Millis};
|
use ntex_util::{future::poll_fn, ready, time, time::Millis};
|
||||||
use tls_rust::{ServerConfig, ServerConnection};
|
use tls_rust::{ServerConfig, ServerConnection};
|
||||||
|
|
||||||
use crate::{rustls::TlsFilter, types};
|
use crate::{rustls::TlsFilter, types};
|
||||||
|
@ -91,14 +91,15 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
||||||
|
|
||||||
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
|
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
|
||||||
let mut session = self.session.borrow_mut();
|
let mut session = self.session.borrow_mut();
|
||||||
|
let mut inner = self.inner.borrow_mut();
|
||||||
|
|
||||||
if session.is_handshaking() {
|
if session.is_handshaking() {
|
||||||
self.inner.borrow_mut().read_buf = Some(src);
|
inner.read_buf = Some(src);
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
if src.is_empty() {
|
if src.is_empty() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
let mut inner = self.inner.borrow_mut();
|
|
||||||
let (hw, lw) = inner.pool.read_params().unpack();
|
let (hw, lw) = inner.pool.read_params().unpack();
|
||||||
|
|
||||||
// get inner filter buffer
|
// get inner filter buffer
|
||||||
|
@ -161,9 +162,8 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !src.is_empty() {
|
if !src.is_empty() {
|
||||||
self.inner.borrow_mut().write_buf = Some(src);
|
inner.write_buf = Some(src);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -206,17 +206,17 @@ impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
||||||
|
|
||||||
impl<F: Filter> TlsServerFilter<F> {
|
impl<F: Filter> TlsServerFilter<F> {
|
||||||
pub(crate) async fn create(
|
pub(crate) async fn create(
|
||||||
st: Io<F>,
|
io: Io<F>,
|
||||||
cfg: Arc<ServerConfig>,
|
cfg: Arc<ServerConfig>,
|
||||||
timeout: Millis,
|
timeout: Millis,
|
||||||
) -> Result<Io<TlsFilter<F>>, io::Error> {
|
) -> Result<Io<TlsFilter<F>>, io::Error> {
|
||||||
time::timeout(timeout, async {
|
time::timeout(timeout, async {
|
||||||
let pool = st.memory_pool();
|
let pool = io.memory_pool();
|
||||||
let session = match ServerConnection::new(cfg) {
|
let session = match ServerConnection::new(cfg) {
|
||||||
Ok(session) => session,
|
Ok(session) => session,
|
||||||
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
||||||
};
|
};
|
||||||
let st = st.map_filter(|inner: F| {
|
let io = io.map_filter(|inner: F| {
|
||||||
let inner = IoInner {
|
let inner = IoInner {
|
||||||
pool,
|
pool,
|
||||||
inner,
|
inner,
|
||||||
|
@ -230,59 +230,51 @@ impl<F: Filter> TlsServerFilter<F> {
|
||||||
}))
|
}))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let filter = st.filter();
|
let filter = io.filter();
|
||||||
let read = st.read();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (result, wants_read) = {
|
let (result, wants_read) = {
|
||||||
let mut session = filter.server().session.borrow_mut();
|
let mut session = filter.server().session.borrow_mut();
|
||||||
let mut inner = filter.server().inner.borrow_mut();
|
let mut inner = filter.server().inner.borrow_mut();
|
||||||
let mut io = Wrapper(&mut *inner);
|
let mut wrp = Wrapper(&mut *inner);
|
||||||
let result = session.complete_io(&mut io);
|
let result = session.complete_io(&mut wrp);
|
||||||
let wants_read = session.wants_read();
|
let wants_read = session.wants_read();
|
||||||
|
|
||||||
if session.wants_write() {
|
if session.wants_write() {
|
||||||
loop {
|
loop {
|
||||||
let n = session.write_tls(&mut io)?;
|
let n = session.write_tls(&mut wrp)?;
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if result.is_ok() && wants_read {
|
|
||||||
poll_fn(|cx| {
|
|
||||||
read.poll_read_ready(cx).map_err(|e| {
|
|
||||||
e.unwrap_or_else(|| {
|
|
||||||
io::Error::new(io::ErrorKind::Other, "disconnected")
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
Poll::Ready(Ok::<_, io::Error>(()))
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
(result, wants_read)
|
(result, wants_read)
|
||||||
};
|
};
|
||||||
|
if result.is_ok() && wants_read {
|
||||||
|
poll_fn(|cx| {
|
||||||
|
match ready!(io.poll_read_ready(cx)) {
|
||||||
|
None => {
|
||||||
|
Err(io::Error::new(io::ErrorKind::Other, "disconnected"))
|
||||||
|
}
|
||||||
|
Some(Err(e)) => Err(e),
|
||||||
|
_ => Ok(()),
|
||||||
|
}?;
|
||||||
|
Poll::Ready(Ok::<_, io::Error>(()))
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
match result {
|
match result {
|
||||||
Ok(_) => return Ok(st),
|
Ok(_) => return Ok(io),
|
||||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||||
if wants_read {
|
|
||||||
read.take_readiness();
|
|
||||||
}
|
|
||||||
poll_fn(|cx| {
|
poll_fn(|cx| {
|
||||||
let read_ready = if wants_read {
|
let read_ready = if wants_read {
|
||||||
if read.is_ready() {
|
match ready!(io.poll_read_ready(cx)) {
|
||||||
true
|
None => Err(io::Error::new(
|
||||||
} else {
|
io::ErrorKind::Other,
|
||||||
read.poll_read_ready(cx).map_err(|e| {
|
"disconnected",
|
||||||
e.unwrap_or_else(|| {
|
)),
|
||||||
io::Error::new(
|
Some(Err(e)) => Err(e),
|
||||||
io::ErrorKind::Other,
|
Some(Ok(_)) => Ok(true),
|
||||||
"disconnected",
|
}?
|
||||||
)
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
true
|
true
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
## [0.5.0-b.0] - 2021-12-xx
|
## [0.5.0-b.1] - 2021-12-20
|
||||||
|
|
||||||
|
* Refactor http/1 dispatcher
|
||||||
|
|
||||||
|
## [0.5.0-b.0] - 2021-12-19
|
||||||
|
|
||||||
* Migrate io to ntex-io
|
* Migrate io to ntex-io
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "ntex"
|
name = "ntex"
|
||||||
version = "0.5.0-b.0"
|
version = "0.5.0-b.1"
|
||||||
authors = ["ntex contributors <team@ntex.rs>"]
|
authors = ["ntex contributors <team@ntex.rs>"]
|
||||||
description = "Framework for composable network services"
|
description = "Framework for composable network services"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -45,8 +45,8 @@ ntex-service = "0.2.1"
|
||||||
ntex-macros = "0.1.3"
|
ntex-macros = "0.1.3"
|
||||||
ntex-util = "0.1.3"
|
ntex-util = "0.1.3"
|
||||||
ntex-bytes = "0.1.8"
|
ntex-bytes = "0.1.8"
|
||||||
ntex-tls = "0.1.0-b.0"
|
ntex-tls = "=0.1.0-b.1"
|
||||||
ntex-io = "0.1.0-b.1"
|
ntex-io = "=0.1.0-b.2"
|
||||||
ntex-rt = { version = "0.4.0-b.0", default-features = false, features = ["tokio"] }
|
ntex-rt = { version = "0.4.0-b.0", default-features = false, features = ["tokio"] }
|
||||||
|
|
||||||
base64 = "0.13"
|
base64 = "0.13"
|
||||||
|
|
|
@ -52,7 +52,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"sending http1 request {:#?} body size: {:?}",
|
"sending http1 request {:?} body size: {:?}",
|
||||||
head,
|
head,
|
||||||
body.size()
|
body.size()
|
||||||
);
|
);
|
||||||
|
@ -74,9 +74,10 @@ where
|
||||||
log::trace!("reading http1 response");
|
log::trace!("reading http1 response");
|
||||||
|
|
||||||
// read response and init read body
|
// read response and init read body
|
||||||
let head = if let Some(result) = io.next(&codec).await? {
|
let head = if let Some(result) = io.next(&codec).await {
|
||||||
|
let result = result?;
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"http1 response is received, type: {:?}, response: {:?}",
|
"http1 response is received, type: {:?}, response: {:#?}",
|
||||||
codec.message_type(),
|
codec.message_type(),
|
||||||
result
|
result
|
||||||
);
|
);
|
||||||
|
@ -107,8 +108,8 @@ pub(super) async fn open_tunnel(
|
||||||
io.send((head, BodySize::None).into(), &codec).await?;
|
io.send((head, BodySize::None).into(), &codec).await?;
|
||||||
|
|
||||||
// read response
|
// read response
|
||||||
if let Some(head) = io.next(&codec).await? {
|
if let Some(head) = io.next(&codec).await {
|
||||||
Ok((head, io, codec))
|
Ok((head?, io, codec))
|
||||||
} else {
|
} else {
|
||||||
Err(SendRequestError::from(ConnectError::Disconnected))
|
Err(SendRequestError::from(ConnectError::Disconnected))
|
||||||
}
|
}
|
||||||
|
@ -123,22 +124,20 @@ pub(super) async fn send_body<B>(
|
||||||
where
|
where
|
||||||
B: MessageBody,
|
B: MessageBody,
|
||||||
{
|
{
|
||||||
let wrt = io.write();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
|
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
|
||||||
Some(result) => {
|
Some(result) => {
|
||||||
if !wrt.encode(h1::Message::Chunk(Some(result?)), codec)? {
|
if !io.encode(h1::Message::Chunk(Some(result?)), codec)? {
|
||||||
wrt.write_ready(false).await?;
|
io.write_ready(false).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
wrt.encode(h1::Message::Chunk(None), codec)?;
|
io.encode(h1::Message::Chunk(None), codec)?;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
wrt.write_ready(true).await?;
|
io.write_ready(true).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -174,8 +173,7 @@ impl Stream for PlStream {
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Self::Item>> {
|
) -> Poll<Option<Self::Item>> {
|
||||||
let mut this = self.as_mut();
|
let mut this = self.as_mut();
|
||||||
|
match this.io.as_ref().unwrap().poll_read_next(&this.codec, cx)? {
|
||||||
match this.io.as_ref().unwrap().poll_next(&this.codec, cx)? {
|
|
||||||
Poll::Pending => Poll::Pending,
|
Poll::Pending => Poll::Pending,
|
||||||
Poll::Ready(Some(chunk)) => {
|
Poll::Ready(Some(chunk)) => {
|
||||||
if let Some(chunk) = chunk {
|
if let Some(chunk) = chunk {
|
||||||
|
@ -198,7 +196,7 @@ fn release_connection(
|
||||||
created: Instant,
|
created: Instant,
|
||||||
mut pool: Option<Acquired>,
|
mut pool: Option<Acquired>,
|
||||||
) {
|
) {
|
||||||
if force_close || io.is_closed() || io.read().with_buf(|buf| !buf.is_empty()) {
|
if force_close || io.is_closed() || io.with_read_buf(|buf| !buf.is_empty()) {
|
||||||
if let Some(mut pool) = pool.take() {
|
if let Some(mut pool) = pool.take() {
|
||||||
pool.close(Connection::new(ConnectionType::H1(io), created, None));
|
pool.close(Connection::new(ConnectionType::H1(io), created, None));
|
||||||
}
|
}
|
||||||
|
|
|
@ -245,7 +245,7 @@ impl Inner {
|
||||||
if s.is_closed() {
|
if s.is_closed() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let is_valid = s.read().with_buf(|buf| {
|
let is_valid = s.with_read_buf(|buf| {
|
||||||
if buf.is_empty() || (buf.len() == 2 && &buf[..] == b"\r\n")
|
if buf.is_empty() || (buf.len() == 2 && &buf[..] == b"\r\n")
|
||||||
{
|
{
|
||||||
buf.clear();
|
buf.clear();
|
||||||
|
|
|
@ -166,10 +166,9 @@ pub enum DispatchError {
|
||||||
/// Upgrade service error
|
/// Upgrade service error
|
||||||
Upgrade(Box<dyn std::error::Error>),
|
Upgrade(Box<dyn std::error::Error>),
|
||||||
|
|
||||||
/// An `io::Error` that occurred while trying to read or write to a network
|
/// Peer is disconnected, error indicates that peer is disconnected because of it
|
||||||
/// stream.
|
#[display(fmt = "Disconnect: {:?}", _0)]
|
||||||
#[display(fmt = "IO error: {}", _0)]
|
Disconnect(Option<io::Error>),
|
||||||
Io(io::Error),
|
|
||||||
|
|
||||||
/// Http request parse error.
|
/// Http request parse error.
|
||||||
#[display(fmt = "Parse error: {}", _0)]
|
#[display(fmt = "Parse error: {}", _0)]
|
||||||
|
@ -215,6 +214,12 @@ pub enum DispatchError {
|
||||||
|
|
||||||
impl std::error::Error for DispatchError {}
|
impl std::error::Error for DispatchError {}
|
||||||
|
|
||||||
|
impl From<io::Error> for DispatchError {
|
||||||
|
fn from(err: io::Error) -> Self {
|
||||||
|
DispatchError::Disconnect(Some(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// A set of error that can occure during parsing content type
|
/// A set of error that can occure during parsing content type
|
||||||
#[derive(PartialEq, Debug, Display)]
|
#[derive(PartialEq, Debug, Display)]
|
||||||
pub enum ContentTypeError {
|
pub enum ContentTypeError {
|
||||||
|
|
|
@ -4,7 +4,7 @@ use std::{error::Error, fmt, future::Future, marker, pin::Pin, rc::Rc, time};
|
||||||
|
|
||||||
use crate::io::{Filter, Io, IoRef};
|
use crate::io::{Filter, Io, IoRef};
|
||||||
use crate::service::Service;
|
use crate::service::Service;
|
||||||
use crate::{time::now, util::Bytes};
|
use crate::{time::now, util::ready, util::Bytes, util::Either};
|
||||||
|
|
||||||
use crate::http;
|
use crate::http;
|
||||||
use crate::http::body::{BodySize, MessageBody, ResponseBody};
|
use crate::http::body::{BodySize, MessageBody, ResponseBody};
|
||||||
|
@ -76,14 +76,6 @@ struct DispatcherInner<F, S, B, X, U> {
|
||||||
_t: marker::PhantomData<(S, B)>,
|
_t: marker::PhantomData<(S, B)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
|
||||||
enum ReadPayloadStatus {
|
|
||||||
Done,
|
|
||||||
Updated,
|
|
||||||
Pending,
|
|
||||||
Dropped,
|
|
||||||
}
|
|
||||||
|
|
||||||
enum WritePayloadStatus<B> {
|
enum WritePayloadStatus<B> {
|
||||||
Next(State<B>),
|
Next(State<B>),
|
||||||
Pause,
|
Pause,
|
||||||
|
@ -171,9 +163,15 @@ where
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
// we might need to read more data into a request payload
|
// we might need to read more data into a request payload
|
||||||
// (ie service future can wait for payload data)
|
// (ie service future can wait for payload data)
|
||||||
if this.inner.poll_read_payload(cx)
|
if this.inner.payload.is_some() {
|
||||||
!= ReadPayloadStatus::Updated
|
if let Err(e) =
|
||||||
{
|
ready!(this.inner.poll_read_payload(cx))
|
||||||
|
{
|
||||||
|
*this.st = State::Stop;
|
||||||
|
this.inner.unregister_keepalive();
|
||||||
|
this.inner.error = Some(e);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
return Poll::Pending;
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -185,7 +183,7 @@ where
|
||||||
Poll::Ready(result) => match result {
|
Poll::Ready(result) => match result {
|
||||||
Ok(req) => {
|
Ok(req) => {
|
||||||
let result =
|
let result =
|
||||||
this.inner.state.write().with_buf(|buf| {
|
this.inner.state.with_write_buf(|buf| {
|
||||||
buf.extend_from_slice(
|
buf.extend_from_slice(
|
||||||
b"HTTP/1.1 100 Continue\r\n\r\n",
|
b"HTTP/1.1 100 Continue\r\n\r\n",
|
||||||
)
|
)
|
||||||
|
@ -262,7 +260,7 @@ where
|
||||||
log::trace!("trying to read http message");
|
log::trace!("trying to read http message");
|
||||||
|
|
||||||
// stop dispatcher
|
// stop dispatcher
|
||||||
if this.inner.state.is_dispatcher_stopped() {
|
if this.inner.io().is_dispatcher_stopped() {
|
||||||
log::trace!("dispatcher is instructed to stop");
|
log::trace!("dispatcher is instructed to stop");
|
||||||
*this.st = State::Stop;
|
*this.st = State::Stop;
|
||||||
this.inner.unregister_keepalive();
|
this.inner.unregister_keepalive();
|
||||||
|
@ -285,165 +283,136 @@ where
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let read = this.inner.state.read();
|
let io = this.inner.io();
|
||||||
|
|
||||||
// decode incoming bytes stream
|
// decode incoming bytes stream
|
||||||
if read.is_ready() {
|
match io.poll_read_next(&this.inner.codec, cx) {
|
||||||
match read.decode(&this.inner.codec) {
|
Poll::Ready(Some(Ok((mut req, pl)))) => {
|
||||||
Ok(Some((mut req, pl))) => {
|
log::trace!(
|
||||||
log::trace!(
|
"http message is received: {:?} and payload {:?}",
|
||||||
"http message is received: {:?} and payload {:?}",
|
req,
|
||||||
req,
|
pl
|
||||||
pl
|
);
|
||||||
);
|
req.head_mut().io = Some(io.get_ref());
|
||||||
req.head_mut().io = Some(this.inner.state.clone());
|
|
||||||
|
|
||||||
// configure request payload
|
// configure request payload
|
||||||
let upgrade = match pl {
|
let upgrade = match pl {
|
||||||
PayloadType::None => false,
|
PayloadType::None => false,
|
||||||
PayloadType::Payload(decoder) => {
|
PayloadType::Payload(decoder) => {
|
||||||
|
let (ps, pl) = Payload::create(false);
|
||||||
|
req.replace_payload(http::Payload::H1(pl));
|
||||||
|
this.inner.payload = Some((decoder, ps));
|
||||||
|
false
|
||||||
|
}
|
||||||
|
PayloadType::Stream(decoder) => {
|
||||||
|
if this.inner.config.upgrade.is_none() {
|
||||||
let (ps, pl) = Payload::create(false);
|
let (ps, pl) = Payload::create(false);
|
||||||
req.replace_payload(http::Payload::H1(pl));
|
req.replace_payload(http::Payload::H1(pl));
|
||||||
this.inner.payload = Some((decoder, ps));
|
this.inner.payload = Some((decoder, ps));
|
||||||
false
|
false
|
||||||
|
} else {
|
||||||
|
this.inner.flags.insert(Flags::UPGRADE);
|
||||||
|
true
|
||||||
}
|
}
|
||||||
PayloadType::Stream(decoder) => {
|
}
|
||||||
if this.inner.config.upgrade.is_none() {
|
};
|
||||||
let (ps, pl) = Payload::create(false);
|
|
||||||
req.replace_payload(http::Payload::H1(pl));
|
// unregister slow-request timer
|
||||||
this.inner.payload = Some((decoder, ps));
|
if !this.inner.flags.contains(Flags::STARTED) {
|
||||||
false
|
this.inner.flags.insert(Flags::STARTED);
|
||||||
} else {
|
this.inner
|
||||||
this.inner.flags.insert(Flags::UPGRADE);
|
.config
|
||||||
true
|
.timer_h1
|
||||||
|
.unregister(this.inner.expire, &this.inner.state);
|
||||||
|
}
|
||||||
|
|
||||||
|
if upgrade {
|
||||||
|
// Handle UPGRADE request
|
||||||
|
log::trace!("prep io for upgrade handler");
|
||||||
|
*this.st = State::Upgrade(Some(req));
|
||||||
|
} else {
|
||||||
|
*this.st = State::Call;
|
||||||
|
this.call.set(
|
||||||
|
if let Some(ref f) = this.inner.config.on_request {
|
||||||
|
// Handle filter fut
|
||||||
|
CallState::Filter {
|
||||||
|
fut: f.call((req, this.inner.state.clone())),
|
||||||
}
|
}
|
||||||
}
|
} else if req.head().expect() {
|
||||||
};
|
// Handle normal requests with EXPECT: 100-Continue` header
|
||||||
|
CallState::Expect {
|
||||||
// unregister slow-request timer
|
fut: this.inner.config.expect.call(req),
|
||||||
if !this.inner.flags.contains(Flags::STARTED) {
|
}
|
||||||
this.inner.flags.insert(Flags::STARTED);
|
} else {
|
||||||
this.inner.config.timer_h1.unregister(
|
// Handle normal requests
|
||||||
this.inner.expire,
|
CallState::Service {
|
||||||
&this.inner.state,
|
fut: this.inner.config.service.call(req),
|
||||||
);
|
}
|
||||||
}
|
},
|
||||||
|
);
|
||||||
if upgrade {
|
|
||||||
// Handle UPGRADE request
|
|
||||||
log::trace!("prep io for upgrade handler");
|
|
||||||
*this.st = State::Upgrade(Some(req));
|
|
||||||
} else {
|
|
||||||
*this.st = State::Call;
|
|
||||||
this.call.set(
|
|
||||||
if let Some(ref f) = this.inner.config.on_request
|
|
||||||
{
|
|
||||||
// Handle filter fut
|
|
||||||
CallState::Filter {
|
|
||||||
fut: f.call((
|
|
||||||
req,
|
|
||||||
this.inner.state.clone(),
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
} else if req.head().expect() {
|
|
||||||
// Handle normal requests with EXPECT: 100-Continue` header
|
|
||||||
CallState::Expect {
|
|
||||||
fut: this.inner.config.expect.call(req),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Handle normal requests
|
|
||||||
CallState::Service {
|
|
||||||
fut: this.inner.config.service.call(req),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(None) => {
|
|
||||||
log::trace!("not enough data to decode http message");
|
|
||||||
|
|
||||||
// if io error occured or connection is not keep-alive
|
|
||||||
// then disconnect
|
|
||||||
if this.inner.flags.contains(Flags::STARTED)
|
|
||||||
&& (!this.inner.flags.contains(Flags::KEEPALIVE)
|
|
||||||
|| !this.inner.codec.keepalive_enabled()
|
|
||||||
|| !this.inner.state.is_io_open())
|
|
||||||
{
|
|
||||||
*this.st = State::Stop;
|
|
||||||
this.inner.unregister_keepalive();
|
|
||||||
this.inner.state.stop_dispatcher();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let _ = read.poll_read_ready(cx);
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
// Malformed requests, respond with 400
|
|
||||||
log::trace!("malformed request: {:?}", err);
|
|
||||||
let (res, body) =
|
|
||||||
Response::BadRequest().finish().into_parts();
|
|
||||||
this.inner.error = Some(DispatchError::Parse(err));
|
|
||||||
*this.st =
|
|
||||||
this.inner.send_response(res, body.into_body());
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
Poll::Ready(Some(Err(Either::Left(err)))) => {
|
||||||
// if connection is not keep-alive then disconnect
|
// Malformed requests, respond with 400
|
||||||
if this.inner.flags.contains(Flags::STARTED)
|
log::trace!("malformed request: {:?}", err);
|
||||||
&& !this.inner.flags.contains(Flags::KEEPALIVE)
|
let (res, body) =
|
||||||
{
|
Response::BadRequest().finish().into_parts();
|
||||||
|
this.inner.error = Some(DispatchError::Parse(err));
|
||||||
|
*this.st = this.inner.send_response(res, body.into_body());
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Err(Either::Right(err)))) => {
|
||||||
|
log::trace!("peer is gone with {:?}", err);
|
||||||
|
// peer is gone
|
||||||
*this.st = State::Stop;
|
*this.st = State::Stop;
|
||||||
this.inner.unregister_keepalive();
|
this.inner.unregister_keepalive();
|
||||||
continue;
|
this.inner.state.stop_dispatcher();
|
||||||
|
return Poll::Ready(Err(DispatchError::Disconnect(Some(
|
||||||
|
err,
|
||||||
|
))));
|
||||||
|
}
|
||||||
|
Poll::Ready(None) => {
|
||||||
|
log::trace!("peer is gone");
|
||||||
|
// peer is gone
|
||||||
|
this.inner.unregister_keepalive();
|
||||||
|
this.inner.state.stop_dispatcher();
|
||||||
|
return Poll::Ready(Err(DispatchError::Disconnect(None)));
|
||||||
|
}
|
||||||
|
Poll::Pending => {
|
||||||
|
log::trace!("not enough data to decode http message");
|
||||||
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
let _ = read.poll_read_ready(cx);
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// consume request's payload
|
// consume request's payload
|
||||||
State::ReadPayload => {
|
State::ReadPayload => match ready!(this.inner.poll_read_payload(cx)) {
|
||||||
if !this.inner.state.is_io_open() {
|
Ok(()) => {
|
||||||
*this.st = State::Stop;
|
*this.st = this.inner.switch_to_read_request();
|
||||||
this.inner.unregister_keepalive();
|
|
||||||
} else {
|
|
||||||
loop {
|
|
||||||
match this.inner.poll_read_payload(cx) {
|
|
||||||
ReadPayloadStatus::Updated => continue,
|
|
||||||
ReadPayloadStatus::Pending => return Poll::Pending,
|
|
||||||
ReadPayloadStatus::Done => {
|
|
||||||
*this.st = {
|
|
||||||
this.inner.reset_keepalive();
|
|
||||||
State::ReadRequest
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ReadPayloadStatus::Dropped => {
|
|
||||||
*this.st = State::Stop;
|
|
||||||
this.inner.unregister_keepalive();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
Err(e) => {
|
||||||
|
*this.st = State::Stop;
|
||||||
|
this.inner.error = Some(e);
|
||||||
|
this.inner.unregister_keepalive();
|
||||||
|
}
|
||||||
|
},
|
||||||
// send response body
|
// send response body
|
||||||
State::SendPayload { ref mut body } => {
|
State::SendPayload { ref mut body } => {
|
||||||
if !this.inner.state.is_io_open() {
|
if !this.inner.state.is_io_open() {
|
||||||
*this.st = State::Stop;
|
*this.st = State::Stop;
|
||||||
|
this.inner.error = Some(this.inner.state.take_error().into());
|
||||||
|
this.inner.unregister_keepalive();
|
||||||
|
} else if let Poll::Ready(Err(e)) = this.inner.poll_read_payload(cx)
|
||||||
|
{
|
||||||
|
*this.st = State::Stop;
|
||||||
|
this.inner.error = Some(e);
|
||||||
this.inner.unregister_keepalive();
|
this.inner.unregister_keepalive();
|
||||||
} else {
|
} else {
|
||||||
this.inner.poll_read_payload(cx);
|
|
||||||
|
|
||||||
match body.poll_next_chunk(cx) {
|
match body.poll_next_chunk(cx) {
|
||||||
Poll::Ready(item) => match this.inner.send_payload(item) {
|
Poll::Ready(item) => match this.inner.send_payload(item) {
|
||||||
WritePayloadStatus::Next(st) => {
|
WritePayloadStatus::Next(st) => {
|
||||||
*this.st = st;
|
*this.st = st;
|
||||||
}
|
}
|
||||||
WritePayloadStatus::Pause => {
|
WritePayloadStatus::Pause => {
|
||||||
this.inner
|
this.inner.io().enable_write_backpressure(cx);
|
||||||
.state
|
|
||||||
.write()
|
|
||||||
.enable_backpressure(Some(cx));
|
|
||||||
return Poll::Pending;
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
WritePayloadStatus::Continue => (),
|
WritePayloadStatus::Continue => (),
|
||||||
|
@ -454,7 +423,7 @@ where
|
||||||
}
|
}
|
||||||
// stop io tasks and call upgrade service
|
// stop io tasks and call upgrade service
|
||||||
State::Upgrade(ref mut req) => {
|
State::Upgrade(ref mut req) => {
|
||||||
log::trace!("initate upgrade handling");
|
log::trace!("switching to upgrade service");
|
||||||
|
|
||||||
let io = this.inner.io.take().unwrap();
|
let io = this.inner.io.take().unwrap();
|
||||||
let req = req.take().unwrap();
|
let req = req.take().unwrap();
|
||||||
|
@ -482,8 +451,9 @@ where
|
||||||
{
|
{
|
||||||
// get io error
|
// get io error
|
||||||
if this.inner.error.is_none() {
|
if this.inner.error.is_none() {
|
||||||
this.inner.error =
|
this.inner.error = Some(DispatchError::Disconnect(
|
||||||
this.inner.state.take_error().map(DispatchError::Io);
|
this.inner.state.take_error(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Poll::Ready(
|
return Poll::Ready(
|
||||||
|
@ -509,6 +479,22 @@ where
|
||||||
S::Response: Into<Response<B>>,
|
S::Response: Into<Response<B>>,
|
||||||
B: MessageBody,
|
B: MessageBody,
|
||||||
{
|
{
|
||||||
|
fn io(&self) -> &Io<T> {
|
||||||
|
self.io.as_ref().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn switch_to_read_request(&mut self) -> State<B> {
|
||||||
|
// connection is not keep-alive, disconnect
|
||||||
|
if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() {
|
||||||
|
self.unregister_keepalive();
|
||||||
|
self.state.stop_dispatcher();
|
||||||
|
State::Stop
|
||||||
|
} else {
|
||||||
|
self.reset_keepalive();
|
||||||
|
State::ReadRequest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn unregister_keepalive(&mut self) {
|
fn unregister_keepalive(&mut self) {
|
||||||
if self.flags.contains(Flags::KEEPALIVE) {
|
if self.flags.contains(Flags::KEEPALIVE) {
|
||||||
self.config.timer_h1.unregister(self.expire, &self.state);
|
self.config.timer_h1.unregister(self.expire, &self.state);
|
||||||
|
@ -524,7 +510,7 @@ where
|
||||||
.timer_h1
|
.timer_h1
|
||||||
.register(expire, self.expire, &self.state);
|
.register(expire, self.expire, &self.state);
|
||||||
self.expire = expire;
|
self.expire = expire;
|
||||||
self.state.reset_keepalive();
|
self.io().reset_keepalive();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -558,8 +544,7 @@ where
|
||||||
// so we skip response processing for droppped connection
|
// so we skip response processing for droppped connection
|
||||||
if self.state.is_io_open() {
|
if self.state.is_io_open() {
|
||||||
let result = self
|
let result = self
|
||||||
.state
|
.io()
|
||||||
.write()
|
|
||||||
.encode(Message::Item((msg, body.size())), &self.codec)
|
.encode(Message::Item((msg, body.size())), &self.codec)
|
||||||
.map_err(|err| {
|
.map_err(|err| {
|
||||||
if let Some(mut payload) = self.payload.take() {
|
if let Some(mut payload) = self.payload.take() {
|
||||||
|
@ -580,8 +565,7 @@ where
|
||||||
} else if self.payload.is_some() {
|
} else if self.payload.is_some() {
|
||||||
State::ReadPayload
|
State::ReadPayload
|
||||||
} else {
|
} else {
|
||||||
self.reset_keepalive();
|
self.switch_to_read_request()
|
||||||
State::ReadRequest
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => State::SendPayload { body },
|
_ => State::SendPayload { body },
|
||||||
|
@ -599,11 +583,7 @@ where
|
||||||
match item {
|
match item {
|
||||||
Some(Ok(item)) => {
|
Some(Ok(item)) => {
|
||||||
trace!("got response chunk: {:?}", item.len());
|
trace!("got response chunk: {:?}", item.len());
|
||||||
match self
|
match self.io().encode(Message::Chunk(Some(item)), &self.codec) {
|
||||||
.state
|
|
||||||
.write()
|
|
||||||
.encode(Message::Chunk(Some(item)), &self.codec)
|
|
||||||
{
|
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
self.error = Some(DispatchError::Encode(err));
|
self.error = Some(DispatchError::Encode(err));
|
||||||
WritePayloadStatus::Next(State::Stop)
|
WritePayloadStatus::Next(State::Stop)
|
||||||
|
@ -619,9 +599,7 @@ where
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
trace!("response payload eof");
|
trace!("response payload eof");
|
||||||
if let Err(err) =
|
if let Err(err) = self.io().encode(Message::Chunk(None), &self.codec) {
|
||||||
self.state.write().encode(Message::Chunk(None), &self.codec)
|
|
||||||
{
|
|
||||||
self.error = Some(DispatchError::Encode(err));
|
self.error = Some(DispatchError::Encode(err));
|
||||||
WritePayloadStatus::Next(State::Stop)
|
WritePayloadStatus::Next(State::Stop)
|
||||||
} else if self.flags.contains(Flags::SENDPAYLOAD_AND_STOP) {
|
} else if self.flags.contains(Flags::SENDPAYLOAD_AND_STOP) {
|
||||||
|
@ -630,7 +608,7 @@ where
|
||||||
WritePayloadStatus::Next(State::ReadPayload)
|
WritePayloadStatus::Next(State::ReadPayload)
|
||||||
} else {
|
} else {
|
||||||
self.reset_keepalive();
|
self.reset_keepalive();
|
||||||
WritePayloadStatus::Next(State::ReadRequest)
|
WritePayloadStatus::Next(self.switch_to_read_request())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(Err(e)) => {
|
Some(Err(e)) => {
|
||||||
|
@ -642,68 +620,67 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process request's payload
|
/// Process request's payload
|
||||||
fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> ReadPayloadStatus {
|
fn poll_read_payload(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), DispatchError>> {
|
||||||
// check if payload data is required
|
// check if payload data is required
|
||||||
if let Some(ref mut payload) = self.payload {
|
if let Some(ref mut payload) = self.payload {
|
||||||
match payload.1.poll_data_required(cx) {
|
match payload.1.poll_data_required(cx) {
|
||||||
PayloadStatus::Read => {
|
PayloadStatus::Read => {
|
||||||
let read = self.state.read();
|
let io = self.io.as_ref().unwrap();
|
||||||
|
|
||||||
// read request payload
|
// read request payload
|
||||||
let mut updated = false;
|
let mut updated = false;
|
||||||
loop {
|
loop {
|
||||||
let item = read.decode(&payload.0);
|
match io.poll_read_next(&payload.0, cx) {
|
||||||
match item {
|
Poll::Ready(Some(Ok(PayloadItem::Chunk(chunk)))) => {
|
||||||
Ok(Some(PayloadItem::Chunk(chunk))) => {
|
|
||||||
updated = true;
|
updated = true;
|
||||||
payload.1.feed_data(chunk);
|
payload.1.feed_data(chunk);
|
||||||
}
|
}
|
||||||
Ok(Some(PayloadItem::Eof)) => {
|
Poll::Ready(Some(Ok(PayloadItem::Eof))) => {
|
||||||
payload.1.feed_eof();
|
payload.1.feed_eof();
|
||||||
self.payload = None;
|
self.payload = None;
|
||||||
if !updated {
|
if !updated {
|
||||||
return ReadPayloadStatus::Done;
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
let _ = read.poll_read_ready(cx);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Ok(None) => {
|
Poll::Ready(None) => {
|
||||||
if !self.state.is_io_open() {
|
|
||||||
payload.1.set_error(PayloadError::EncodingCorrupted);
|
|
||||||
self.payload = None;
|
|
||||||
self.error = Some(ParseError::Incomplete.into());
|
|
||||||
return ReadPayloadStatus::Dropped;
|
|
||||||
} else {
|
|
||||||
let _ = read.poll_read_ready(cx);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
payload.1.set_error(PayloadError::EncodingCorrupted);
|
payload.1.set_error(PayloadError::EncodingCorrupted);
|
||||||
self.payload = None;
|
self.payload = None;
|
||||||
self.error = Some(DispatchError::Parse(e));
|
return Poll::Ready(Err(ParseError::Incomplete.into()));
|
||||||
return ReadPayloadStatus::Dropped;
|
|
||||||
}
|
}
|
||||||
|
Poll::Ready(Some(Err(e))) => {
|
||||||
|
payload.1.set_error(PayloadError::EncodingCorrupted);
|
||||||
|
self.payload = None;
|
||||||
|
return Poll::Ready(Err(match e {
|
||||||
|
Either::Left(e) => DispatchError::Parse(e),
|
||||||
|
Either::Right(e) => {
|
||||||
|
DispatchError::Disconnect(Some(e))
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Poll::Pending => break,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if updated {
|
if updated {
|
||||||
ReadPayloadStatus::Updated
|
Poll::Ready(Ok(()))
|
||||||
} else {
|
} else {
|
||||||
ReadPayloadStatus::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
PayloadStatus::Pause => ReadPayloadStatus::Pending,
|
PayloadStatus::Pause => Poll::Pending,
|
||||||
PayloadStatus::Dropped => {
|
PayloadStatus::Dropped => {
|
||||||
// service call is not interested in payload
|
// service call is not interested in payload
|
||||||
// wait until future completes and then close
|
// wait until future completes and then close
|
||||||
// connection
|
// connection
|
||||||
self.payload = None;
|
self.payload = None;
|
||||||
self.error = Some(DispatchError::PayloadIsNotConsumed);
|
Poll::Ready(Err(DispatchError::PayloadIsNotConsumed))
|
||||||
ReadPayloadStatus::Dropped
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
ReadPayloadStatus::Done
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -852,7 +829,6 @@ mod tests {
|
||||||
|
|
||||||
client.close().await;
|
client.close().await;
|
||||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
|
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
|
||||||
// assert!(h1.inner.flags.contains(Flags::SHUTDOWN_IO));
|
|
||||||
assert!(!h1.inner.state.is_io_open());
|
assert!(!h1.inner.state.is_io_open());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -865,14 +841,14 @@ mod tests {
|
||||||
Ok::<_, io::Error>(Response::Ok().finish())
|
Ok::<_, io::Error>(Response::Ok().finish())
|
||||||
});
|
});
|
||||||
|
|
||||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
client.write("GET /test1 HTTP/1.1\r\n\r\n");
|
||||||
|
|
||||||
let mut buf = client.read().await.unwrap();
|
let mut buf = client.read().await.unwrap();
|
||||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||||
assert!(!client.is_server_dropped());
|
assert!(!client.is_server_dropped());
|
||||||
|
|
||||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
client.write("GET /test2 HTTP/1.1\r\n\r\n");
|
||||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
client.write("GET /test3 HTTP/1.1\r\n\r\n");
|
||||||
|
|
||||||
let mut buf = client.read().await.unwrap();
|
let mut buf = client.read().await.unwrap();
|
||||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||||
|
@ -896,7 +872,7 @@ mod tests {
|
||||||
Ok::<_, io::Error>(Response::Ok().finish())
|
Ok::<_, io::Error>(Response::Ok().finish())
|
||||||
});
|
});
|
||||||
|
|
||||||
client.write("GET /test HTTP/1.1\r\ncontent-length: 5\r\n\r\n");
|
client.write("GET /test1 HTTP/1.1\r\ncontent-length: 5\r\n\r\n");
|
||||||
sleep(Millis(50)).await;
|
sleep(Millis(50)).await;
|
||||||
client.write("xxxxx");
|
client.write("xxxxx");
|
||||||
|
|
||||||
|
@ -904,7 +880,7 @@ mod tests {
|
||||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||||
assert!(!client.is_server_dropped());
|
assert!(!client.is_server_dropped());
|
||||||
|
|
||||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
client.write("GET /test2 HTTP/1.1\r\n\r\n");
|
||||||
|
|
||||||
let mut buf = client.read().await.unwrap();
|
let mut buf = client.read().await.unwrap();
|
||||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||||
|
@ -1103,11 +1079,11 @@ mod tests {
|
||||||
assert_eq!(num.load(Ordering::Relaxed), 65_536);
|
assert_eq!(num.load(Ordering::Relaxed), 65_536);
|
||||||
|
|
||||||
// response message + chunking encoding
|
// response message + chunking encoding
|
||||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 65629);
|
assert_eq!(state.with_write_buf(|buf| buf.len()).unwrap(), 65629);
|
||||||
|
|
||||||
client.remote_buffer_cap(65536);
|
client.remote_buffer_cap(65536);
|
||||||
sleep(Millis(50)).await;
|
sleep(Millis(50)).await;
|
||||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 93);
|
assert_eq!(state.with_write_buf(|buf| buf.len()).unwrap(), 93);
|
||||||
|
|
||||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
|
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
|
||||||
assert_eq!(num.load(Ordering::Relaxed), 65_536 * 2);
|
assert_eq!(num.load(Ordering::Relaxed), 65_536 * 2);
|
||||||
|
|
|
@ -11,7 +11,7 @@ pub mod variant;
|
||||||
pub use self::extensions::Extensions;
|
pub use self::extensions::Extensions;
|
||||||
|
|
||||||
pub use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut, Pool, PoolId, PoolRef};
|
pub use ntex_bytes::{Buf, BufMut, ByteString, Bytes, BytesMut, Pool, PoolId, PoolRef};
|
||||||
pub use ntex_util::future::*;
|
pub use ntex_util::{future::*, ready};
|
||||||
|
|
||||||
pub type HashMap<K, V> = std::collections::HashMap<K, V, fxhash::FxBuildHasher>;
|
pub type HashMap<K, V> = std::collections::HashMap<K, V, fxhash::FxBuildHasher>;
|
||||||
pub type HashSet<V> = std::collections::HashSet<V, fxhash::FxBuildHasher>;
|
pub type HashSet<V> = std::collections::HashSet<V, fxhash::FxBuildHasher>;
|
||||||
|
|
|
@ -28,9 +28,7 @@ impl WsSink {
|
||||||
_ => false,
|
_ => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let wrt = inner.io.write();
|
inner.io.encode(item, &inner.codec)?;
|
||||||
wrt.write_ready(false).await?;
|
|
||||||
wrt.encode(item, &inner.codec)?;
|
|
||||||
if close {
|
if close {
|
||||||
inner.io.close();
|
inner.io.close();
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,12 +33,11 @@ async fn test_simple() {
|
||||||
let res = handshake_response(req.head()).finish();
|
let res = handshake_response(req.head()).finish();
|
||||||
|
|
||||||
// send handshake respone
|
// send handshake respone
|
||||||
io.write()
|
io.encode(
|
||||||
.encode(
|
h1::Message::Item((res.drop_body(), BodySize::None)),
|
||||||
h1::Message::Item((res.drop_body(), BodySize::None)),
|
&codec,
|
||||||
&codec,
|
)
|
||||||
)
|
.unwrap();
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// start websocket service
|
// start websocket service
|
||||||
Dispatcher::new(
|
Dispatcher::new(
|
||||||
|
|
|
@ -46,8 +46,7 @@ impl Service for WsService {
|
||||||
let fut = async move {
|
let fut = async move {
|
||||||
let res = handshake(req.head()).unwrap().message_body(());
|
let res = handshake(req.head()).unwrap().message_body(());
|
||||||
|
|
||||||
io.write()
|
io.encode((res, body::BodySize::None).into(), &codec)
|
||||||
.encode((res, body::BodySize::None).into(), &codec)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Dispatcher::new(io.into_boxed(), ws::Codec::new(), service, Timer::default())
|
Dispatcher::new(io.into_boxed(), ws::Codec::new(), service, Timer::default())
|
||||||
|
|
|
@ -69,7 +69,6 @@ async fn web_ws() {
|
||||||
|
|
||||||
#[ntex::test]
|
#[ntex::test]
|
||||||
async fn web_ws_client() {
|
async fn web_ws_client() {
|
||||||
env_logger::init();
|
|
||||||
let srv = test::server(|| {
|
let srv = test::server(|| {
|
||||||
App::new().service(web::resource("/").route(web::to(
|
App::new().service(web::resource("/").route(web::to(
|
||||||
|req: HttpRequest, pl: web::types::Payload| async move {
|
|req: HttpRequest, pl: web::types::Payload| async move {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue