mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-04 13:27:39 +03:00
Refactor write back-pressure (#39)
* refactor error handling * refactor write back-pressure
This commit is contained in:
parent
f0fe2bbc59
commit
26543a4247
21 changed files with 567 additions and 899 deletions
|
@ -1,5 +1,11 @@
|
|||
# Changes
|
||||
|
||||
## [0.2.0-b.7] - 2021-01-25
|
||||
|
||||
* Fix error handling for framed disaptcher
|
||||
|
||||
* Refactor framed disaptcher write back-pressure support
|
||||
|
||||
## [0.2.0-b.6] - 2021-01-24
|
||||
|
||||
* http: Pass io stream to upgrade handler
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex"
|
||||
version = "0.2.0-b.6"
|
||||
version = "0.2.0-b.7"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "Framework for composable network services"
|
||||
readme = "README.md"
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
//! Framed transport dispatcher
|
||||
use std::task::{Context, Poll};
|
||||
use std::{
|
||||
cell::Cell, cell::RefCell, fmt, future::Future, pin::Pin, rc::Rc, time::Duration,
|
||||
time::Instant,
|
||||
};
|
||||
use std::{cell::Cell, cell::RefCell, pin::Pin, rc::Rc, time::Duration, time::Instant};
|
||||
|
||||
use either::Either;
|
||||
use futures::FutureExt;
|
||||
use futures::{ready, Future, FutureExt};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
|
||||
use crate::framed::{DispatchItem, ReadTask, State, Timer, WriteTask};
|
||||
|
@ -29,7 +26,7 @@ pin_project_lite::pin_project! {
|
|||
service: S,
|
||||
inner: DispatcherInner<S, U>,
|
||||
#[pin]
|
||||
response: Option<S::Future>,
|
||||
fut: Option<S::Future>,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -38,11 +35,12 @@ where
|
|||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||
U: Encoder + Decoder,
|
||||
{
|
||||
st: DispatcherState,
|
||||
st: Cell<DispatcherState>,
|
||||
state: State,
|
||||
timer: Timer,
|
||||
updated: Instant,
|
||||
keepalive_timeout: u16,
|
||||
ka_timeout: u16,
|
||||
ka_updated: Cell<Instant>,
|
||||
error: Cell<Option<S::Error>>,
|
||||
shared: Rc<DispatcherShared<S, U>>,
|
||||
}
|
||||
|
||||
|
@ -59,16 +57,24 @@ where
|
|||
#[derive(Copy, Clone, Debug)]
|
||||
enum DispatcherState {
|
||||
Processing,
|
||||
WrEnabled,
|
||||
WrWaitReady,
|
||||
Stop,
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
pub(crate) enum DispatcherError<S, U> {
|
||||
enum DispatcherError<S, U> {
|
||||
KeepAlive,
|
||||
Encoder(U),
|
||||
Service(S),
|
||||
}
|
||||
|
||||
enum PollService<U: Encoder + Decoder> {
|
||||
Item(DispatchItem<U>),
|
||||
ServiceError,
|
||||
Ready,
|
||||
}
|
||||
|
||||
impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
|
||||
fn from(err: Either<S, U>) -> Self {
|
||||
match err {
|
||||
|
@ -78,19 +84,6 @@ impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<E1, E2: fmt::Debug> DispatcherError<E1, E2> {
|
||||
fn convert<U>(self) -> Option<DispatchItem<U>>
|
||||
where
|
||||
U: Encoder<Error = E2> + Decoder,
|
||||
{
|
||||
match self {
|
||||
DispatcherError::KeepAlive => Some(DispatchItem::KeepAliveTimeout),
|
||||
DispatcherError::Encoder(err) => Some(DispatchItem::EncoderError(err)),
|
||||
DispatcherError::Service(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U> Dispatcher<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
|
||||
|
@ -125,21 +118,22 @@ where
|
|||
timer: Timer,
|
||||
) -> Self {
|
||||
let updated = timer.now();
|
||||
let keepalive_timeout: u16 = 30;
|
||||
let ka_timeout: u16 = 30;
|
||||
|
||||
// register keepalive timer
|
||||
let expire = updated + Duration::from_secs(keepalive_timeout as u64);
|
||||
let expire = updated + Duration::from_secs(ka_timeout as u64);
|
||||
timer.register(expire, expire, &state);
|
||||
|
||||
Dispatcher {
|
||||
service: service.into_service(),
|
||||
response: None,
|
||||
fut: None,
|
||||
inner: DispatcherInner {
|
||||
state,
|
||||
timer,
|
||||
updated,
|
||||
keepalive_timeout,
|
||||
st: DispatcherState::Processing,
|
||||
ka_timeout,
|
||||
ka_updated: Cell::new(updated),
|
||||
error: Cell::new(None),
|
||||
st: Cell::new(DispatcherState::Processing),
|
||||
shared: Rc::new(DispatcherShared {
|
||||
codec,
|
||||
error: Cell::new(None),
|
||||
|
@ -156,15 +150,15 @@ where
|
|||
/// By default keep-alive timeout is set to 30 seconds.
|
||||
pub fn keepalive_timeout(mut self, timeout: u16) -> Self {
|
||||
// register keepalive timer
|
||||
let prev = self.inner.updated
|
||||
+ Duration::from_secs(self.inner.keepalive_timeout as u64);
|
||||
let prev = self.inner.ka_updated.get() + self.inner.ka();
|
||||
if timeout == 0 {
|
||||
self.inner.timer.unregister(prev, &self.inner.state);
|
||||
} else {
|
||||
let expire = self.inner.updated + Duration::from_secs(timeout as u64);
|
||||
let expire =
|
||||
self.inner.ka_updated.get() + Duration::from_secs(timeout as u64);
|
||||
self.inner.timer.register(expire, prev, &self.inner.state);
|
||||
}
|
||||
self.inner.keepalive_timeout = timeout;
|
||||
self.inner.ka_timeout = timeout;
|
||||
|
||||
self
|
||||
}
|
||||
|
@ -191,71 +185,14 @@ where
|
|||
U: Encoder + Decoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
{
|
||||
fn handle_result(
|
||||
&self,
|
||||
item: Result<S::Response, S::Error>,
|
||||
state: &State,
|
||||
wake: bool,
|
||||
) {
|
||||
fn handle_result(&self, item: Result<S::Response, S::Error>, state: &State) {
|
||||
self.inflight.set(self.inflight.get() - 1);
|
||||
if let Err(err) = state.write_result(item, &self.codec) {
|
||||
self.error.set(Some(err.into()));
|
||||
}
|
||||
|
||||
if wake {
|
||||
state.dsp_wake_task()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U> DispatcherInner<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||
U: Decoder + Encoder,
|
||||
{
|
||||
fn take_error(&self) -> Option<DispatchItem<U>> {
|
||||
// check for errors
|
||||
self.shared
|
||||
.error
|
||||
.take()
|
||||
.and_then(|err| err.convert())
|
||||
.or_else(|| self.state.take_io_error().map(DispatchItem::IoError))
|
||||
}
|
||||
|
||||
/// check keepalive timeout
|
||||
fn check_keepalive(&self) {
|
||||
if self.state.is_keepalive() {
|
||||
log::trace!("keepalive timeout");
|
||||
if let Some(err) = self.shared.error.take() {
|
||||
self.shared.error.set(Some(err));
|
||||
} else {
|
||||
self.shared.error.set(Some(DispatcherError::KeepAlive));
|
||||
}
|
||||
self.state.dsp_mark_stopped();
|
||||
}
|
||||
}
|
||||
|
||||
/// update keep-alive timer
|
||||
fn update_keepalive(&mut self) {
|
||||
if self.keepalive_timeout != 0 {
|
||||
let updated = self.timer.now();
|
||||
if updated != self.updated {
|
||||
let ka = Duration::from_secs(self.keepalive_timeout as u64);
|
||||
self.timer
|
||||
.register(updated + ka, self.updated + ka, &self.state);
|
||||
self.updated = updated;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// unregister keep-alive timer
|
||||
fn unregister_keepalive(&self) {
|
||||
if self.keepalive_timeout != 0 {
|
||||
self.timer.unregister(
|
||||
self.updated + Duration::from_secs(self.keepalive_timeout as u64),
|
||||
&self.state,
|
||||
);
|
||||
match state.write_result(item, &self.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => state.enable_write_backpressure(),
|
||||
Err(err) => self.error.set(Some(err.into())),
|
||||
}
|
||||
state.dsp_wake_task();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -269,144 +206,132 @@ where
|
|||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.as_mut().project();
|
||||
let slf = &this.inner;
|
||||
let state = &slf.state;
|
||||
|
||||
// handle service response future
|
||||
if let Some(fut) = this.response.as_mut().as_pin_mut() {
|
||||
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
|
||||
match fut.poll(cx) {
|
||||
Poll::Pending => (),
|
||||
Poll::Ready(item) => {
|
||||
this.inner
|
||||
.shared
|
||||
.handle_result(item, &this.inner.state, false);
|
||||
this.response.set(None);
|
||||
this.fut.set(None);
|
||||
slf.shared.inflight.set(slf.shared.inflight.get() - 1);
|
||||
let _ = slf.handle_result(item, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match this.inner.st {
|
||||
DispatcherState::Processing => {
|
||||
loop {
|
||||
match this.service.poll_ready(cx) {
|
||||
Poll::Ready(Ok(_)) => {
|
||||
let mut retry = false;
|
||||
|
||||
// service is ready, wake io read task
|
||||
this.inner.state.dsp_restart_read_task();
|
||||
|
||||
// check keepalive timeout
|
||||
this.inner.check_keepalive();
|
||||
|
||||
let item = if this.inner.state.is_dsp_stopped() {
|
||||
log::trace!("dispatcher is instructed to stop");
|
||||
|
||||
// unregister keep-alive timer
|
||||
this.inner.unregister_keepalive();
|
||||
|
||||
// check for errors
|
||||
retry = true;
|
||||
this.inner.st = DispatcherState::Stop;
|
||||
this.inner.take_error()
|
||||
} else {
|
||||
// decode incoming bytes stream
|
||||
if this.inner.state.is_read_ready() {
|
||||
let item = this
|
||||
.inner
|
||||
.state
|
||||
.decode_item(&this.inner.shared.codec);
|
||||
match item {
|
||||
Ok(Some(el)) => {
|
||||
this.inner.update_keepalive();
|
||||
Some(DispatchItem::Item(el))
|
||||
}
|
||||
Ok(None) => {
|
||||
log::trace!("not enough data to decode next frame, register dispatch task");
|
||||
this.inner
|
||||
.state
|
||||
.dsp_read_more_data(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
Err(err) => {
|
||||
retry = true;
|
||||
this.inner.st = DispatcherState::Stop;
|
||||
this.inner.unregister_keepalive();
|
||||
Some(DispatchItem::DecoderError(err))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this.inner.state.dsp_register_task(cx.waker());
|
||||
return Poll::Pending;
|
||||
match slf.st.get() {
|
||||
DispatcherState::WrEnabled => {
|
||||
let item = match ready!(slf.poll_service(&this.service, cx)) {
|
||||
PollService::Ready => {
|
||||
slf.st.set(DispatcherState::WrWaitReady);
|
||||
DispatchItem::WBackPressureEnabled
|
||||
}
|
||||
PollService::Item(item) => item,
|
||||
PollService::ServiceError => continue,
|
||||
};
|
||||
|
||||
// call service
|
||||
if let Some(item) = item {
|
||||
// optimize first call
|
||||
if this.response.is_none() {
|
||||
this.response.set(Some(this.service.call(item)));
|
||||
let res = this
|
||||
.response
|
||||
.as_mut()
|
||||
.as_pin_mut()
|
||||
.unwrap()
|
||||
.poll(cx);
|
||||
|
||||
if let Poll::Ready(res) = res {
|
||||
if let Err(err) = this
|
||||
.inner
|
||||
.state
|
||||
.write_result(res, &this.inner.shared.codec)
|
||||
{
|
||||
this.inner
|
||||
.shared
|
||||
.error
|
||||
.set(Some(err.into()));
|
||||
}
|
||||
this.response.set(None);
|
||||
} else {
|
||||
this.inner
|
||||
.shared
|
||||
.inflight
|
||||
.set(this.inner.shared.inflight.get() + 1);
|
||||
}
|
||||
} else {
|
||||
this.inner
|
||||
.shared
|
||||
.inflight
|
||||
.set(this.inner.shared.inflight.get() + 1);
|
||||
let st = this.inner.state.clone();
|
||||
let shared = this.inner.shared.clone();
|
||||
crate::rt::spawn(this.service.call(item).map(
|
||||
move |item| {
|
||||
shared.handle_result(item, &st, true);
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// run again
|
||||
if retry {
|
||||
return self.poll(cx);
|
||||
}
|
||||
if this.fut.is_none() {
|
||||
// optimize first service call
|
||||
this.fut.set(Some(this.service.call(item)));
|
||||
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
||||
Poll::Ready(res) => {
|
||||
this.fut.set(None);
|
||||
ready!(slf.handle_result(res, cx));
|
||||
}
|
||||
Poll::Pending => {
|
||||
// pause io read task
|
||||
log::trace!("service is not ready, register dispatch task");
|
||||
this.inner.state.dsp_service_not_ready(cx.waker());
|
||||
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slf.spawn_service_call(this.service.call(item));
|
||||
}
|
||||
}
|
||||
DispatcherState::WrWaitReady => {
|
||||
let item = match ready!(slf.poll_service(&this.service, cx)) {
|
||||
PollService::Ready => {
|
||||
if state.is_write_backpressure_disabled() {
|
||||
slf.st.set(DispatcherState::Processing);
|
||||
DispatchItem::WBackPressureDisabled
|
||||
} else {
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
// handle service readiness error
|
||||
log::trace!("service readiness check failed, stopping");
|
||||
this.inner.st = DispatcherState::Stop;
|
||||
this.inner.state.dsp_mark_stopped();
|
||||
this.inner
|
||||
.shared
|
||||
.error
|
||||
.set(Some(DispatcherError::Service(err)));
|
||||
this.inner.unregister_keepalive();
|
||||
return self.poll(cx);
|
||||
}
|
||||
PollService::Item(item) => item,
|
||||
PollService::ServiceError => continue,
|
||||
};
|
||||
|
||||
// call service
|
||||
if this.fut.is_none() {
|
||||
// optimize first service call
|
||||
this.fut.set(Some(this.service.call(item)));
|
||||
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
||||
Poll::Ready(res) => {
|
||||
this.fut.set(None);
|
||||
ready!(slf.handle_result(res, cx));
|
||||
}
|
||||
Poll::Pending => {
|
||||
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slf.spawn_service_call(this.service.call(item));
|
||||
}
|
||||
}
|
||||
DispatcherState::Processing => {
|
||||
let item = match ready!(slf.poll_service(&this.service, cx)) {
|
||||
PollService::Ready => {
|
||||
if state.is_write_backpressure_enabled() {
|
||||
// instruct write task to notify dispatcher when data is flushed
|
||||
state.dsp_enable_write_backpressure(cx.waker());
|
||||
slf.st.set(DispatcherState::WrWaitReady);
|
||||
DispatchItem::WBackPressureEnabled
|
||||
} else if state.is_read_ready() {
|
||||
// decode incoming bytes if buffer is ready
|
||||
match state.decode_item(&slf.shared.codec) {
|
||||
Ok(Some(el)) => {
|
||||
slf.update_keepalive();
|
||||
DispatchItem::Item(el)
|
||||
}
|
||||
Ok(None) => {
|
||||
log::trace!("not enough data to decode next frame, register dispatch task");
|
||||
state.dsp_read_more_data(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
Err(err) => {
|
||||
slf.st.set(DispatcherState::Stop);
|
||||
slf.unregister_keepalive();
|
||||
DispatchItem::DecoderError(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// no new events
|
||||
state.dsp_register_task(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
PollService::Item(item) => item,
|
||||
PollService::ServiceError => continue,
|
||||
};
|
||||
|
||||
// call service
|
||||
if this.fut.is_none() {
|
||||
// optimize first service call
|
||||
this.fut.set(Some(this.service.call(item)));
|
||||
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
||||
Poll::Ready(res) => {
|
||||
this.fut.set(None);
|
||||
ready!(slf.handle_result(res, cx));
|
||||
}
|
||||
Poll::Pending => {
|
||||
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slf.spawn_service_call(this.service.call(item));
|
||||
}
|
||||
}
|
||||
// drain service responses
|
||||
|
@ -414,34 +339,176 @@ where
|
|||
// service may relay on poll_ready for response results
|
||||
let _ = this.service.poll_ready(cx);
|
||||
|
||||
if this.inner.shared.inflight.get() == 0 {
|
||||
this.inner.state.shutdown_io();
|
||||
this.inner.st = DispatcherState::Shutdown;
|
||||
self.poll(cx)
|
||||
if slf.shared.inflight.get() == 0 {
|
||||
slf.st.set(DispatcherState::Shutdown);
|
||||
state.shutdown_io();
|
||||
} else {
|
||||
this.inner.state.dsp_register_task(cx.waker());
|
||||
Poll::Pending
|
||||
state.dsp_register_task(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
// shutdown service
|
||||
DispatcherState::Shutdown => {
|
||||
let err = this.inner.shared.error.take();
|
||||
let err = slf.error.take();
|
||||
|
||||
if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
|
||||
return if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
|
||||
log::trace!("service shutdown is completed, stop");
|
||||
|
||||
Poll::Ready(if let Some(DispatcherError::Service(err)) = err {
|
||||
Poll::Ready(if let Some(err) = err {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
} else {
|
||||
this.inner.shared.error.set(err);
|
||||
slf.error.set(err);
|
||||
Poll::Pending
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U> DispatcherInner<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
|
||||
U: Decoder + Encoder + 'static,
|
||||
{
|
||||
/// spawn service call
|
||||
fn spawn_service_call(&self, fut: S::Future) {
|
||||
self.shared.inflight.set(self.shared.inflight.get() + 1);
|
||||
|
||||
let st = self.state.clone();
|
||||
let shared = self.shared.clone();
|
||||
crate::rt::spawn(fut.map(move |item| shared.handle_result(item, &st)));
|
||||
}
|
||||
|
||||
fn handle_result(
|
||||
&self,
|
||||
item: Result<Option<<U as Encoder>::Item>, S::Error>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<()> {
|
||||
match self.state.write_result(item, &self.shared.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => {
|
||||
// instruct write task to notify dispatcher when data is flushed
|
||||
self.state.dsp_enable_write_backpressure(cx.waker());
|
||||
self.st.set(DispatcherState::WrEnabled);
|
||||
return Poll::Pending;
|
||||
}
|
||||
Err(Either::Left(err)) => {
|
||||
self.error.set(Some(err));
|
||||
}
|
||||
Err(Either::Right(err)) => {
|
||||
self.shared.error.set(Some(DispatcherError::Encoder(err)))
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
|
||||
fn poll_service(&self, srv: &S, cx: &mut Context<'_>) -> Poll<PollService<U>> {
|
||||
match srv.poll_ready(cx) {
|
||||
Poll::Ready(Ok(_)) => {
|
||||
// service is ready, wake io read task
|
||||
self.state.dsp_restart_read_task();
|
||||
|
||||
// check keepalive timeout
|
||||
self.check_keepalive();
|
||||
|
||||
// check for errors
|
||||
Poll::Ready(if let Some(err) = self.shared.error.take() {
|
||||
log::trace!("error occured, stopping dispatcher");
|
||||
self.unregister_keepalive();
|
||||
self.st.set(DispatcherState::Stop);
|
||||
|
||||
match err {
|
||||
DispatcherError::KeepAlive => {
|
||||
PollService::Item(DispatchItem::KeepAliveTimeout)
|
||||
}
|
||||
DispatcherError::Encoder(err) => {
|
||||
PollService::Item(DispatchItem::EncoderError(err))
|
||||
}
|
||||
DispatcherError::Service(err) => {
|
||||
self.error.set(Some(err));
|
||||
PollService::ServiceError
|
||||
}
|
||||
}
|
||||
} else if self.state.is_dsp_stopped() {
|
||||
log::trace!("dispatcher is instructed to stop");
|
||||
|
||||
self.unregister_keepalive();
|
||||
self.st.set(DispatcherState::Stop);
|
||||
|
||||
// get io error
|
||||
if let Some(err) = self.state.take_io_error() {
|
||||
PollService::Item(DispatchItem::IoError(err))
|
||||
} else {
|
||||
PollService::ServiceError
|
||||
}
|
||||
} else {
|
||||
PollService::Ready
|
||||
})
|
||||
}
|
||||
// pause io read task
|
||||
Poll::Pending => {
|
||||
log::trace!("service is not ready, register dispatch task");
|
||||
self.state.dsp_service_not_ready(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
// handle service readiness error
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("service readiness check failed, stopping");
|
||||
self.st.set(DispatcherState::Stop);
|
||||
self.error.set(Some(err));
|
||||
self.unregister_keepalive();
|
||||
Poll::Ready(PollService::ServiceError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ka(&self) -> Duration {
|
||||
Duration::from_secs(self.ka_timeout as u64)
|
||||
}
|
||||
|
||||
fn ka_enabled(&self) -> bool {
|
||||
self.ka_timeout > 0
|
||||
}
|
||||
|
||||
/// check keepalive timeout
|
||||
fn check_keepalive(&self) {
|
||||
if self.state.is_keepalive() {
|
||||
log::trace!("keepalive timeout");
|
||||
if let Some(err) = self.shared.error.take() {
|
||||
self.shared.error.set(Some(err));
|
||||
} else {
|
||||
self.shared.error.set(Some(DispatcherError::KeepAlive));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// update keep-alive timer
|
||||
fn update_keepalive(&self) {
|
||||
if self.ka_enabled() {
|
||||
let updated = self.timer.now();
|
||||
if updated != self.ka_updated.get() {
|
||||
let ka = self.ka();
|
||||
self.timer.register(
|
||||
updated + ka,
|
||||
self.ka_updated.get() + ka,
|
||||
&self.state,
|
||||
);
|
||||
self.ka_updated.set(updated);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// unregister keep-alive timer
|
||||
fn unregister_keepalive(&self) {
|
||||
if self.ka_enabled() {
|
||||
self.timer
|
||||
.unregister(self.ka_updated.get() + self.ka(), &self.state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -473,8 +540,8 @@ mod tests {
|
|||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
let timer = Timer::default();
|
||||
let keepalive_timeout = 30;
|
||||
let updated = timer.now();
|
||||
let ka_timeout = 30;
|
||||
let ka_updated = timer.now();
|
||||
let state = State::new();
|
||||
let io = Rc::new(RefCell::new(io));
|
||||
let shared = Rc::new(DispatcherShared {
|
||||
|
@ -489,14 +556,15 @@ mod tests {
|
|||
(
|
||||
Dispatcher {
|
||||
service: service.into_service(),
|
||||
response: None,
|
||||
fut: None,
|
||||
inner: DispatcherInner {
|
||||
shared,
|
||||
timer,
|
||||
updated,
|
||||
keepalive_timeout,
|
||||
ka_timeout,
|
||||
ka_updated: Cell::new(ka_updated),
|
||||
state: state.clone(),
|
||||
st: DispatcherState::Processing,
|
||||
error: Cell::new(None),
|
||||
st: Cell::new(DispatcherState::Processing),
|
||||
},
|
||||
},
|
||||
state,
|
||||
|
|
|
@ -17,6 +17,10 @@ use crate::codec::{Decoder, Encoder};
|
|||
/// Framed transport item
|
||||
pub enum DispatchItem<U: Encoder + Decoder> {
|
||||
Item(<U as Decoder>::Item),
|
||||
/// Write back-pressure enabled
|
||||
WBackPressureEnabled,
|
||||
/// Write back-pressure disabled
|
||||
WBackPressureDisabled,
|
||||
/// Keep alive timeout
|
||||
KeepAliveTimeout,
|
||||
/// Decoder parse error
|
||||
|
@ -37,6 +41,12 @@ where
|
|||
DispatchItem::Item(ref item) => {
|
||||
write!(fmt, "DispatchItem::Item({:?})", item)
|
||||
}
|
||||
DispatchItem::WBackPressureEnabled => {
|
||||
write!(fmt, "DispatchItem::WBackPressureEnabled")
|
||||
}
|
||||
DispatchItem::WBackPressureDisabled => {
|
||||
write!(fmt, "DispatchItem::WBackPressureDisabled")
|
||||
}
|
||||
DispatchItem::KeepAliveTimeout => {
|
||||
write!(fmt, "DispatchItem::KeepAliveTimeout")
|
||||
}
|
||||
|
@ -52,3 +62,28 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codec::BytesCodec;
|
||||
|
||||
#[test]
|
||||
fn test_fmt() {
|
||||
type T = DispatchItem<BytesCodec>;
|
||||
|
||||
let err = T::EncoderError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::Encoder"));
|
||||
let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::Decoder"));
|
||||
let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::IoError"));
|
||||
|
||||
assert!(format!("{:?}", T::WBackPressureEnabled)
|
||||
.contains("DispatchItem::WBackPressureEnabled"));
|
||||
assert!(format!("{:?}", T::WBackPressureDisabled)
|
||||
.contains("DispatchItem::WBackPressureDisabled"));
|
||||
assert!(format!("{:?}", T::KeepAliveTimeout)
|
||||
.contains("DispatchItem::KeepAliveTimeout"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,29 +10,31 @@ use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}
|
|||
use crate::framed::write::flush;
|
||||
use crate::task::LocalWaker;
|
||||
|
||||
const HW: usize = 8 * 1024;
|
||||
const HW: usize = 16 * 1024;
|
||||
|
||||
bitflags::bitflags! {
|
||||
pub struct Flags: u16 {
|
||||
const DSP_STOP = 0b0000_0001;
|
||||
const DSP_KEEPALIVE = 0b0000_0010;
|
||||
const DSP_STOP = 0b0000_0000_0001;
|
||||
const DSP_KEEPALIVE = 0b0000_0000_0010;
|
||||
|
||||
/// io error occured
|
||||
const IO_ERR = 0b0000_0100;
|
||||
const IO_ERR = 0b0000_0000_0100;
|
||||
/// stop io tasks
|
||||
const IO_STOP = 0b0000_1000;
|
||||
const IO_STOP = 0b0000_0000_1000;
|
||||
/// shutdown io tasks
|
||||
const IO_SHUTDOWN = 0b0001_0000;
|
||||
const IO_SHUTDOWN = 0b0000_0001_0000;
|
||||
|
||||
/// pause io read
|
||||
const RD_PAUSED = 0b0010_0000;
|
||||
const RD_PAUSED = 0b0000_0010_0000;
|
||||
/// new data is available
|
||||
const RD_READY = 0b0100_0000;
|
||||
const RD_READY = 0b0000_0100_0000;
|
||||
|
||||
/// write task is ready
|
||||
const WR_READY = 0b0001_0000_0000;
|
||||
/// write buffer is full
|
||||
const WR_NOT_READY = 0b1000_0000;
|
||||
const WR_NOT_READY = 0b0010_0000_0000;
|
||||
|
||||
const ST_DSP_ERR = 0b10000_0000;
|
||||
const ST_DSP_ERR = 0b0001_0000_0000_0000;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,14 +165,43 @@ impl State {
|
|||
self.0.flags.get().contains(Flags::RD_READY)
|
||||
}
|
||||
|
||||
/// read task must be paused if service is not ready (RD_PAUSED)
|
||||
pub(super) fn is_read_paused(&self) -> bool {
|
||||
self.0.flags.get().contains(Flags::RD_PAUSED)
|
||||
self.0.flags.get().intersects(Flags::RD_PAUSED)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if write buffer is ready
|
||||
pub fn is_write_ready(&self) -> bool {
|
||||
!self.0.flags.get().contains(Flags::WR_NOT_READY)
|
||||
/// Check if write back-pressure is disabled
|
||||
pub fn is_write_backpressure_disabled(&self) -> bool {
|
||||
let mut flags = self.0.flags.get();
|
||||
if flags.contains(Flags::WR_READY) {
|
||||
flags.remove(Flags::WR_READY);
|
||||
self.0.flags.set(flags);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if write back-pressure is enabled
|
||||
pub fn is_write_backpressure_enabled(&self) -> bool {
|
||||
let mut flags = self.0.flags.get();
|
||||
if flags.contains(Flags::WR_READY) {
|
||||
flags.remove(Flags::WR_READY);
|
||||
self.0.flags.set(flags);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Enable write back-persurre
|
||||
pub fn enable_write_backpressure(&self) {
|
||||
let mut flags = self.0.flags.get();
|
||||
flags.insert(Flags::WR_NOT_READY);
|
||||
self.0.flags.set(flags);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -272,6 +303,7 @@ impl State {
|
|||
let mut flags = self.0.flags.get();
|
||||
if flags.contains(Flags::WR_NOT_READY) {
|
||||
flags.remove(Flags::WR_NOT_READY);
|
||||
flags.insert(Flags::WR_READY);
|
||||
self.0.flags.set(flags);
|
||||
self.0.dispatch_task.wake();
|
||||
}
|
||||
|
@ -298,16 +330,24 @@ impl State {
|
|||
self.0.dispatch_task.register(waker);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if write buff is full
|
||||
pub fn is_write_buf_full(&self) -> bool {
|
||||
self.0.write_buf.borrow().len() >= HW
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Wait until write task flushes data to socket
|
||||
pub fn dsp_flush_write_data(&self, waker: &Waker) {
|
||||
///
|
||||
/// Write task must be waken up separately.
|
||||
pub fn dsp_enable_write_backpressure(&self, waker: &Waker) {
|
||||
let mut flags = self.0.flags.get();
|
||||
flags.insert(Flags::WR_NOT_READY);
|
||||
self.0.flags.set(flags);
|
||||
self.0.write_task.wake();
|
||||
self.0.dispatch_task.register(waker);
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[inline]
|
||||
/// Mark dispatcher as stopped
|
||||
pub fn dsp_mark_stopped(&self) {
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::codec::{AsyncRead, AsyncWrite};
|
|||
use crate::framed::State;
|
||||
use crate::rt::time::{delay_for, Delay};
|
||||
|
||||
const HW: usize = 8 * 1024;
|
||||
const HW: usize = 16 * 1024;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum IoWriteState {
|
||||
|
|
|
@ -1,20 +1,17 @@
|
|||
//! Websockets client
|
||||
use std::convert::TryFrom;
|
||||
use std::net::SocketAddr;
|
||||
use std::rc::Rc;
|
||||
use std::{fmt, str};
|
||||
use std::{convert::TryFrom, fmt, net::SocketAddr, rc::Rc, str};
|
||||
|
||||
#[cfg(feature = "cookie")]
|
||||
use coo_kie::{Cookie, CookieJar};
|
||||
use futures::Stream;
|
||||
use futures::future::{err, ok, Either};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Framed};
|
||||
use crate::framed::{DispatchItem, Dispatcher, State};
|
||||
use crate::http::error::HttpError;
|
||||
use crate::http::header::{self, HeaderName, HeaderValue, AUTHORIZATION};
|
||||
use crate::http::{ConnectionType, Payload, RequestHead, StatusCode, Uri};
|
||||
use crate::rt::time::timeout;
|
||||
use crate::service::{IntoService, Service};
|
||||
use crate::util::framed::{Dispatcher, DispatcherError};
|
||||
use crate::service::{apply_fn, IntoService, Service};
|
||||
use crate::ws;
|
||||
|
||||
pub use crate::ws::{CloseCode, CloseReason, Frame, Message};
|
||||
|
@ -221,9 +218,7 @@ impl WsRequest {
|
|||
}
|
||||
|
||||
/// Complete request construction and connect to a websockets server.
|
||||
pub async fn connect(
|
||||
mut self,
|
||||
) -> Result<(ClientResponse, Framed<BoxedSocket, ws::Codec>), WsClientError> {
|
||||
pub async fn connect(mut self) -> Result<WsConnection, WsClientError> {
|
||||
if let Some(e) = self.err.take() {
|
||||
return Err(WsClientError::from(e));
|
||||
}
|
||||
|
@ -378,8 +373,8 @@ impl WsRequest {
|
|||
return Err(WsClientError::MissingWebSocketAcceptHeader);
|
||||
};
|
||||
|
||||
// response and ws framed
|
||||
Ok((
|
||||
// response and ws io
|
||||
Ok(WsConnection::new(
|
||||
ClientResponse::new(head, Payload::None),
|
||||
framed.map_codec(|_| {
|
||||
if server_mode {
|
||||
|
@ -407,21 +402,70 @@ impl fmt::Debug for WsRequest {
|
|||
}
|
||||
}
|
||||
|
||||
/// Start client websockets service.
|
||||
pub async fn start<Io, T, F, Rx>(
|
||||
framed: Framed<Io, ws::Codec>,
|
||||
rx: Rx,
|
||||
service: F,
|
||||
) -> Result<(), DispatcherError<T::Error, ws::Codec>>
|
||||
pub struct WsConnection<Io = BoxedSocket> {
|
||||
io: Io,
|
||||
state: State,
|
||||
codec: ws::Codec,
|
||||
res: ClientResponse,
|
||||
}
|
||||
|
||||
impl<Io> WsConnection<Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<Request = ws::Frame, Response = Option<ws::Message>>,
|
||||
T::Error: 'static,
|
||||
T::Future: 'static,
|
||||
F: IntoService<T>,
|
||||
Rx: Stream<Item = ws::Message> + Unpin + 'static,
|
||||
{
|
||||
Dispatcher::with(framed, Some(rx), service.into_service()).await
|
||||
fn new(res: ClientResponse, framed: Framed<Io, ws::Codec>) -> Self {
|
||||
let (io, codec, state) = State::from_framed(framed);
|
||||
|
||||
Self {
|
||||
io,
|
||||
codec,
|
||||
state,
|
||||
res,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get ws sink
|
||||
pub fn sink(&self) -> ws::WsSink {
|
||||
ws::WsSink::new(self.state.clone(), self.codec.clone())
|
||||
}
|
||||
|
||||
/// Get reference to response
|
||||
pub fn response(&self) -> &ClientResponse {
|
||||
&self.res
|
||||
}
|
||||
|
||||
/// Start client websockets service.
|
||||
pub async fn start<T, F, Rx>(self, service: F) -> Result<(), ws::WsError<T::Error>>
|
||||
where
|
||||
T: Service<Request = ws::Frame, Response = Option<ws::Message>> + 'static,
|
||||
F: IntoService<T>,
|
||||
{
|
||||
let service = apply_fn(
|
||||
service.into_service().map_err(ws::WsError::Service),
|
||||
|req, srv| match req {
|
||||
DispatchItem::Item(item) => Either::Left(srv.call(item)),
|
||||
DispatchItem::WBackPressureEnabled
|
||||
| DispatchItem::WBackPressureDisabled => Either::Right(ok(None)),
|
||||
DispatchItem::KeepAliveTimeout => {
|
||||
Either::Right(err(ws::WsError::KeepAlive))
|
||||
}
|
||||
DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => {
|
||||
Either::Right(err(ws::WsError::Protocol(e)))
|
||||
}
|
||||
DispatchItem::IoError(e) => Either::Right(err(ws::WsError::Io(e))),
|
||||
},
|
||||
);
|
||||
|
||||
Dispatcher::new(self.io, self.codec, self.state, service, Default::default())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Consumes the `WsConnection`, returning it'as underlying I/O framed object
|
||||
/// and response.
|
||||
pub fn into_inner(self) -> (ClientResponse, Framed<Io, ws::Codec>) {
|
||||
let framed = self.state.into_framed(self.io, self.codec);
|
||||
(self.res, framed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -9,11 +9,8 @@ pub use actix_threadpool::BlockingError;
|
|||
pub use futures::channel::oneshot::Canceled;
|
||||
pub use http::Error as HttpError;
|
||||
|
||||
use crate::codec::{Decoder, Encoder};
|
||||
use crate::util::framed::DispatcherError;
|
||||
|
||||
use super::body::Body;
|
||||
use super::response::Response;
|
||||
use crate::http::body::Body;
|
||||
use crate::http::response::Response;
|
||||
|
||||
/// Error that can be converted to `Response`
|
||||
pub trait ResponseError: fmt::Display + fmt::Debug {
|
||||
|
@ -60,14 +57,6 @@ impl ResponseError for io::Error {}
|
|||
/// `InternalServerError` for `JsonError`
|
||||
impl ResponseError for serde_json::error::Error {}
|
||||
|
||||
impl<E, U: Encoder + Decoder + 'static> ResponseError for DispatcherError<E, U>
|
||||
where
|
||||
E: fmt::Debug + fmt::Display + 'static,
|
||||
<U as Encoder>::Error: fmt::Debug,
|
||||
<U as Decoder>::Error: fmt::Debug,
|
||||
{
|
||||
}
|
||||
|
||||
/// A set of errors that can occur during parsing HTTP streams
|
||||
#[derive(Debug, Display, From)]
|
||||
pub enum ParseError {
|
||||
|
|
|
@ -402,7 +402,9 @@ where
|
|||
*this.st = st;
|
||||
}
|
||||
WritePayloadStatus::Pause => {
|
||||
this.inner.state.dsp_flush_write_data(cx.waker());
|
||||
this.inner
|
||||
.state
|
||||
.dsp_enable_write_backpressure(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
WritePayloadStatus::Continue => (),
|
||||
|
|
|
@ -328,7 +328,7 @@ impl TestServer {
|
|||
{
|
||||
let url = self.url(path);
|
||||
let connect = self.client.ws(url).connect();
|
||||
connect.await.map(|(_, framed)| framed)
|
||||
connect.await.map(|ws| ws.into_inner().1)
|
||||
}
|
||||
|
||||
/// Connect to a websocket server
|
||||
|
|
|
@ -1,554 +0,0 @@
|
|||
//! Framed transport dispatcher
|
||||
use std::{fmt, io, pin::Pin, task::Context, task::Poll, time::Duration};
|
||||
|
||||
use either::Either;
|
||||
use futures::{ready, Future, FutureExt, Stream};
|
||||
use log::debug;
|
||||
|
||||
use crate::channel::mpsc;
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
|
||||
use crate::rt::time::{delay_for, Delay};
|
||||
use crate::service::{IntoService, Service};
|
||||
|
||||
type Request<U> = <U as Decoder>::Item;
|
||||
type Response<U> = <U as Encoder>::Item;
|
||||
|
||||
/// Framed transport errors
|
||||
pub enum DispatcherError<E, U: Encoder + Decoder> {
|
||||
/// Inner service error
|
||||
Service(E),
|
||||
/// Encoder parse error
|
||||
Encoder(<U as Encoder>::Error),
|
||||
/// Decoder parse error
|
||||
Decoder(<U as Decoder>::Error),
|
||||
/// Unexpected io error
|
||||
IoError(io::Error),
|
||||
}
|
||||
|
||||
impl<E, U: Encoder + Decoder> From<E> for DispatcherError<E, U> {
|
||||
fn from(err: E) -> Self {
|
||||
DispatcherError::Service(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, U: Encoder + Decoder> From<Either<E, io::Error>> for DispatcherError<E, U> {
|
||||
fn from(err: Either<E, io::Error>) -> Self {
|
||||
match err {
|
||||
Either::Left(err) => DispatcherError::Service(err),
|
||||
Either::Right(err) => DispatcherError::IoError(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, U: Encoder + Decoder> fmt::Debug for DispatcherError<E, U>
|
||||
where
|
||||
E: fmt::Debug,
|
||||
<U as Encoder>::Error: fmt::Debug,
|
||||
<U as Decoder>::Error: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
DispatcherError::Service(ref e) => {
|
||||
write!(fmt, "DispatcherError::Service({:?})", e)
|
||||
}
|
||||
DispatcherError::Encoder(ref e) => {
|
||||
write!(fmt, "DispatcherError::Encoder({:?})", e)
|
||||
}
|
||||
DispatcherError::Decoder(ref e) => {
|
||||
write!(fmt, "DispatcherError::Decoder({:?})", e)
|
||||
}
|
||||
DispatcherError::IoError(ref e) => {
|
||||
write!(fmt, "DispatcherError::IoError({:?})", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E, U: Encoder + Decoder> fmt::Display for DispatcherError<E, U>
|
||||
where
|
||||
E: fmt::Display,
|
||||
<U as Encoder>::Error: fmt::Debug,
|
||||
<U as Decoder>::Error: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
DispatcherError::Service(ref e) => write!(fmt, "{}", e),
|
||||
DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e),
|
||||
DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e),
|
||||
DispatcherError::IoError(ref e) => write!(fmt, "{}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
/// FramedTransport - is a future that reads frames from Framed object
|
||||
/// and pass then to the service.
|
||||
pub struct Dispatcher<S, T, U, Out>
|
||||
where
|
||||
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
T: AsyncRead,
|
||||
T: AsyncWrite,
|
||||
T: Unpin,
|
||||
U: Encoder,
|
||||
U: Decoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
<U as Encoder>::Error: std::fmt::Debug,
|
||||
Out: Stream<Item = <U as Encoder>::Item>,
|
||||
Out: Unpin,
|
||||
{
|
||||
inner: InnerDispatcher<S, T, U, Out>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T, U> Dispatcher<S, T, U, mpsc::Receiver<<U as Encoder>::Item>>
|
||||
where
|
||||
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
U: Decoder + Encoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
<U as Encoder>::Error: std::fmt::Debug,
|
||||
{
|
||||
/// Construct new `Dispatcher` instance
|
||||
pub fn new<F: IntoService<S>>(framed: Framed<T, U>, service: F) -> Self {
|
||||
Dispatcher {
|
||||
inner: InnerDispatcher {
|
||||
framed,
|
||||
sink: None,
|
||||
rx: mpsc::channel().1,
|
||||
service: service.into_service(),
|
||||
state: FramedState::Processing,
|
||||
disconnect_timeout: 1000,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T, U, In> Dispatcher<S, T, U, In>
|
||||
where
|
||||
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
U: Decoder + Encoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
<U as Encoder>::Error: std::fmt::Debug,
|
||||
In: Stream<Item = <U as Encoder>::Item> + Unpin,
|
||||
{
|
||||
/// Construct new `Dispatcher` instance with outgoing messages stream.
|
||||
pub fn with<F: IntoService<S>>(
|
||||
framed: Framed<T, U>,
|
||||
sink: Option<In>,
|
||||
service: F,
|
||||
) -> Self {
|
||||
Dispatcher {
|
||||
inner: InnerDispatcher {
|
||||
framed,
|
||||
sink,
|
||||
rx: mpsc::channel().1,
|
||||
service: service.into_service(),
|
||||
state: FramedState::Processing,
|
||||
disconnect_timeout: 1000,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Set connection disconnect timeout in milliseconds.
|
||||
///
|
||||
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
|
||||
/// within this time, the connection get dropped.
|
||||
///
|
||||
/// To disable timeout set value to 0.
|
||||
///
|
||||
/// By default disconnect timeout is set to 1 seconds.
|
||||
pub fn disconnect_timeout(mut self, val: u64) -> Self {
|
||||
self.inner.disconnect_timeout = val;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T, U, In> Future for Dispatcher<S, T, U, In>
|
||||
where
|
||||
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
U: Decoder + Encoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
<U as Encoder>::Error: std::fmt::Debug,
|
||||
In: Stream<Item = <U as Encoder>::Item> + Unpin,
|
||||
{
|
||||
type Output = Result<(), DispatcherError<S::Error, U>>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.project().inner.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
enum FramedState<S: Service, U: Encoder + Decoder> {
|
||||
Processing,
|
||||
FlushAndStop(Option<DispatcherError<S::Error, U>>),
|
||||
Shutdown(Option<DispatcherError<S::Error, U>>),
|
||||
ShutdownIo(Delay, Option<Result<(), DispatcherError<S::Error, U>>>),
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
||||
enum PollResult {
|
||||
Continue,
|
||||
Pending,
|
||||
}
|
||||
|
||||
struct InnerDispatcher<S, T, U, Out>
|
||||
where
|
||||
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
U: Encoder + Decoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
<U as Encoder>::Error: std::fmt::Debug,
|
||||
Out: Stream<Item = <U as Encoder>::Item> + Unpin,
|
||||
{
|
||||
service: S,
|
||||
sink: Option<Out>,
|
||||
state: FramedState<S, U>,
|
||||
framed: Framed<T, U>,
|
||||
rx: mpsc::Receiver<Result<<U as Encoder>::Item, S::Error>>,
|
||||
disconnect_timeout: u64,
|
||||
}
|
||||
|
||||
impl<S, T, U, Out> InnerDispatcher<S, T, U, Out>
|
||||
where
|
||||
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
U: Decoder + Encoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
<U as Encoder>::Error: std::fmt::Debug,
|
||||
Out: Stream<Item = <U as Encoder>::Item> + Unpin,
|
||||
{
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>) -> PollResult {
|
||||
loop {
|
||||
match self.service.poll_ready(cx) {
|
||||
Poll::Ready(Ok(_)) => {
|
||||
let item = match self.framed.next_item(cx) {
|
||||
Poll::Ready(Some(Ok(el))) => el,
|
||||
Poll::Ready(Some(Err(err))) => {
|
||||
log::trace!("Framed decode error");
|
||||
self.state = match err {
|
||||
Either::Left(err) => FramedState::Shutdown(Some(
|
||||
DispatcherError::Decoder(err),
|
||||
)),
|
||||
Either::Right(err) => FramedState::Shutdown(Some(
|
||||
DispatcherError::IoError(err),
|
||||
)),
|
||||
};
|
||||
return PollResult::Continue;
|
||||
}
|
||||
Poll::Pending => return PollResult::Pending,
|
||||
Poll::Ready(None) => {
|
||||
log::trace!("Client disconnected");
|
||||
self.state = FramedState::Shutdown(None);
|
||||
return PollResult::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
let tx = self.rx.sender();
|
||||
crate::rt::spawn(self.service.call(item).map(move |item| {
|
||||
let item = match item {
|
||||
Ok(Some(item)) => Ok(item),
|
||||
Err(err) => Err(err),
|
||||
_ => return,
|
||||
};
|
||||
let _ = tx.send(item);
|
||||
}));
|
||||
}
|
||||
Poll::Pending => return PollResult::Pending,
|
||||
Poll::Ready(Err(err)) => {
|
||||
self.state =
|
||||
FramedState::FlushAndStop(Some(DispatcherError::Service(err)));
|
||||
return PollResult::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// write to framed object
|
||||
fn poll_write(&mut self, cx: &mut Context<'_>) -> PollResult {
|
||||
loop {
|
||||
while !self.framed.is_write_buf_full() {
|
||||
match Pin::new(&mut self.rx).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(msg))) => {
|
||||
if let Err(err) = self.framed.write(msg) {
|
||||
log::trace!("Framed write error: {:?}", err);
|
||||
self.state = FramedState::Shutdown(Some(
|
||||
DispatcherError::Encoder(err),
|
||||
));
|
||||
return PollResult::Continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => {
|
||||
self.state = FramedState::FlushAndStop(Some(
|
||||
DispatcherError::Service(err),
|
||||
));
|
||||
return PollResult::Continue;
|
||||
}
|
||||
Poll::Ready(None) | Poll::Pending => {}
|
||||
}
|
||||
|
||||
if let Some(ref mut sink) = self.sink {
|
||||
match Pin::new(sink).poll_next(cx) {
|
||||
Poll::Ready(Some(msg)) => {
|
||||
if let Err(err) = self.framed.write(msg) {
|
||||
log::trace!("Framed write error from sink: {:?}", err);
|
||||
self.state = FramedState::Shutdown(Some(
|
||||
DispatcherError::Encoder(err),
|
||||
));
|
||||
return PollResult::Continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
let _ = self.sink.take();
|
||||
self.state = FramedState::FlushAndStop(None);
|
||||
return PollResult::Continue;
|
||||
}
|
||||
Poll::Pending => (),
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if !self.framed.is_write_buf_empty() {
|
||||
match self.framed.flush(cx) {
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Ok(_)) => (),
|
||||
Poll::Ready(Err(err)) => {
|
||||
debug!("Error sending data: {:?}", err);
|
||||
self.state =
|
||||
FramedState::Shutdown(Some(DispatcherError::IoError(err)));
|
||||
return PollResult::Continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
PollResult::Pending
|
||||
}
|
||||
|
||||
pub(super) fn poll(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), DispatcherError<S::Error, U>>> {
|
||||
loop {
|
||||
match self.state {
|
||||
FramedState::Processing => {
|
||||
let read = self.poll_read(cx);
|
||||
let write = self.poll_write(cx);
|
||||
if read == PollResult::Continue || write == PollResult::Continue {
|
||||
continue;
|
||||
} else {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
FramedState::FlushAndStop(ref mut err) => {
|
||||
// drain service responses
|
||||
match Pin::new(&mut self.rx).poll_next(cx) {
|
||||
Poll::Ready(Some(Ok(msg))) => {
|
||||
if let Err(err) = self.framed.write(msg) {
|
||||
log::trace!("Framed write message error: {:?}", err);
|
||||
self.state = FramedState::Shutdown(Some(
|
||||
DispatcherError::Encoder(err),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => {
|
||||
log::trace!("Sink poll error");
|
||||
self.state = FramedState::Shutdown(Some(err.into()));
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) | Poll::Pending => (),
|
||||
}
|
||||
|
||||
// flush io
|
||||
if !self.framed.is_write_buf_empty() {
|
||||
match self.framed.flush(cx) {
|
||||
Poll::Ready(Err(err)) => {
|
||||
debug!("Error sending data: {:?}", err);
|
||||
}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(_) => (),
|
||||
}
|
||||
};
|
||||
log::trace!("Framed flushed, shutdown");
|
||||
self.state = FramedState::Shutdown(err.take());
|
||||
}
|
||||
FramedState::Shutdown(ref mut err) => {
|
||||
return if self.service.poll_shutdown(cx, err.is_some()).is_ready() {
|
||||
let result = if let Some(err) = err.take() {
|
||||
if let DispatcherError::Service(_) = err {
|
||||
Err(err)
|
||||
} else {
|
||||
// no need for io shutdown because io error occured
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
};
|
||||
|
||||
// frame close, closes io WR side and waits for disconnect
|
||||
// on read side. we need disconnect timeout, because it
|
||||
// could hang forever.
|
||||
let pending = self.framed.close(cx).is_pending();
|
||||
if self.disconnect_timeout != 0 && pending {
|
||||
self.state = FramedState::ShutdownIo(
|
||||
delay_for(Duration::from_millis(
|
||||
self.disconnect_timeout,
|
||||
)),
|
||||
Some(result),
|
||||
);
|
||||
continue;
|
||||
} else {
|
||||
Poll::Ready(result)
|
||||
}
|
||||
} else {
|
||||
Poll::Pending
|
||||
};
|
||||
}
|
||||
FramedState::ShutdownIo(ref mut delay, ref mut err) => {
|
||||
if let Poll::Ready(res) = self.framed.close(cx) {
|
||||
return match err.take() {
|
||||
Some(Ok(_)) | None => {
|
||||
Poll::Ready(res.map_err(DispatcherError::IoError))
|
||||
}
|
||||
Some(Err(e)) => Poll::Ready(Err(e)),
|
||||
};
|
||||
} else {
|
||||
ready!(Pin::new(delay).poll(cx));
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use derive_more::Display;
|
||||
use futures::future::ok;
|
||||
use std::io;
|
||||
|
||||
use super::*;
|
||||
use crate::channel::mpsc;
|
||||
use crate::codec::{BytesCodec, Framed};
|
||||
use crate::rt::time::delay_for;
|
||||
use crate::testing::Io;
|
||||
|
||||
#[test]
|
||||
fn test_err() {
|
||||
#[derive(Debug, Display)]
|
||||
struct TestError;
|
||||
type T = DispatcherError<TestError, BytesCodec>;
|
||||
let err = T::Encoder(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatcherError::Encoder"));
|
||||
assert!(format!("{}", err).contains("Custom"));
|
||||
let err = T::Decoder(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatcherError::Decoder"));
|
||||
assert!(format!("{}", err).contains("Custom"));
|
||||
let err = T::from(TestError);
|
||||
assert!(format!("{:?}", err).contains("DispatcherError::Service"));
|
||||
assert_eq!(format!("{}", err), "TestError");
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_basic() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let framed = Framed::new(server, BytesCodec);
|
||||
let disp = Dispatcher::new(
|
||||
framed,
|
||||
crate::fn_service(|msg: BytesMut| async move {
|
||||
delay_for(Duration::from_millis(50)).await;
|
||||
Ok::<_, ()>(Some(msg.freeze()))
|
||||
}),
|
||||
);
|
||||
crate::rt::spawn(disp.map(|_| ()));
|
||||
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_sink() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let (tx, rx) = mpsc::channel();
|
||||
let framed = Framed::new(server, BytesCodec);
|
||||
let disp = Dispatcher::with(
|
||||
framed,
|
||||
Some(rx),
|
||||
crate::fn_service(|msg: BytesMut| ok::<_, ()>(Some(msg.freeze()))),
|
||||
)
|
||||
.disconnect_timeout(25);
|
||||
crate::rt::spawn(disp.map(|_| ()));
|
||||
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
assert!(tx.send(Bytes::from_static(b"test")).is_ok());
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"test"));
|
||||
|
||||
drop(tx);
|
||||
delay_for(Duration::from_millis(200)).await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_err_in_service() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(0);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let mut framed = Framed::new(server, BytesCodec);
|
||||
framed.write_buf().extend(b"GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let disp = Dispatcher::new(
|
||||
framed,
|
||||
crate::fn_service(|_: BytesMut| async { Err::<Option<Bytes>, _>(()) }),
|
||||
);
|
||||
crate::rt::spawn(disp.map(|_| ()));
|
||||
|
||||
let buf = client.read_any();
|
||||
assert_eq!(buf, Bytes::from_static(b""));
|
||||
delay_for(Duration::from_millis(25)).await;
|
||||
|
||||
// buffer should be flushed
|
||||
client.remote_buffer_cap(1024);
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
// write side must be closed, dispatcher waiting for read side to close
|
||||
assert!(client.is_closed());
|
||||
|
||||
// close read side
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
}
|
|
@ -2,7 +2,6 @@ pub mod buffer;
|
|||
pub mod counter;
|
||||
pub mod either;
|
||||
mod extensions;
|
||||
pub mod framed;
|
||||
pub mod inflight;
|
||||
pub mod keepalive;
|
||||
pub mod stream;
|
||||
|
|
|
@ -156,11 +156,9 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::BytesMut;
|
||||
use futures::future::ok;
|
||||
use futures::StreamExt;
|
||||
use std::cell::Cell;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
use bytestring::ByteString;
|
||||
use futures::{future::ok, StreamExt};
|
||||
use std::{cell::Cell, rc::Rc, time::Duration};
|
||||
|
||||
use super::*;
|
||||
use crate::channel::mpsc;
|
||||
|
@ -183,7 +181,7 @@ mod tests {
|
|||
encoder,
|
||||
crate::fn_service(move |_| {
|
||||
counter2.set(counter2.get() + 1);
|
||||
ok(Some(ws::Message::Text("test".to_string())))
|
||||
ok(Some(ws::Message::Text(ByteString::from_static("test"))))
|
||||
}),
|
||||
);
|
||||
crate::rt::spawn(disp.map(|_| ()));
|
||||
|
@ -191,7 +189,7 @@ mod tests {
|
|||
let mut buf = BytesMut::new();
|
||||
let codec = ws::Codec::new().client_mode();
|
||||
codec
|
||||
.encode(ws::Message::Text("test".to_string()), &mut buf)
|
||||
.encode(ws::Message::Text(ByteString::from_static("test")), &mut buf)
|
||||
.unwrap();
|
||||
tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap();
|
||||
|
||||
|
|
|
@ -940,7 +940,7 @@ impl TestServer {
|
|||
{
|
||||
let url = self.url(path);
|
||||
let connect = self.client.ws(url).connect();
|
||||
connect.await.map(|(_, framed)| framed)
|
||||
connect.await.map(|ws| ws.into_inner().1)
|
||||
}
|
||||
|
||||
/// Connect to a websocket server
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use bytes::{Bytes, BytesMut};
|
||||
use bytestring::ByteString;
|
||||
use std::cell::Cell;
|
||||
|
||||
use crate::codec::{Decoder, Encoder};
|
||||
|
@ -11,7 +12,7 @@ use super::ProtocolError;
|
|||
#[derive(Debug, PartialEq)]
|
||||
pub enum Message {
|
||||
/// Text message
|
||||
Text(String),
|
||||
Text(ByteString),
|
||||
/// Binary message
|
||||
Binary(Bytes),
|
||||
/// Continuation
|
||||
|
@ -60,7 +61,7 @@ pub struct Codec {
|
|||
bitflags::bitflags! {
|
||||
struct Flags: u8 {
|
||||
const SERVER = 0b0000_0001;
|
||||
const CONTINUATION = 0b0000_0010;
|
||||
const R_CONTINUATION = 0b0000_0010;
|
||||
const W_CONTINUATION = 0b0000_0100;
|
||||
}
|
||||
}
|
||||
|
@ -222,7 +223,7 @@ impl Decoder for Codec {
|
|||
if !finished {
|
||||
return match opcode {
|
||||
OpCode::Continue => {
|
||||
if self.flags.get().contains(Flags::CONTINUATION) {
|
||||
if self.flags.get().contains(Flags::R_CONTINUATION) {
|
||||
Ok(Some(Frame::Continuation(Item::Continue(
|
||||
payload
|
||||
.map(|pl| pl.freeze())
|
||||
|
@ -233,8 +234,8 @@ impl Decoder for Codec {
|
|||
}
|
||||
}
|
||||
OpCode::Binary => {
|
||||
if !self.flags.get().contains(Flags::CONTINUATION) {
|
||||
self.insert_flags(Flags::CONTINUATION);
|
||||
if !self.flags.get().contains(Flags::R_CONTINUATION) {
|
||||
self.insert_flags(Flags::R_CONTINUATION);
|
||||
Ok(Some(Frame::Continuation(Item::FirstBinary(
|
||||
payload
|
||||
.map(|pl| pl.freeze())
|
||||
|
@ -245,8 +246,8 @@ impl Decoder for Codec {
|
|||
}
|
||||
}
|
||||
OpCode::Text => {
|
||||
if !self.flags.get().contains(Flags::CONTINUATION) {
|
||||
self.insert_flags(Flags::CONTINUATION);
|
||||
if !self.flags.get().contains(Flags::R_CONTINUATION) {
|
||||
self.insert_flags(Flags::R_CONTINUATION);
|
||||
Ok(Some(Frame::Continuation(Item::FirstText(
|
||||
payload
|
||||
.map(|pl| pl.freeze())
|
||||
|
@ -265,8 +266,8 @@ impl Decoder for Codec {
|
|||
|
||||
match opcode {
|
||||
OpCode::Continue => {
|
||||
if self.flags.get().contains(Flags::CONTINUATION) {
|
||||
self.remove_flags(Flags::CONTINUATION);
|
||||
if self.flags.get().contains(Flags::R_CONTINUATION) {
|
||||
self.remove_flags(Flags::R_CONTINUATION);
|
||||
Ok(Some(Frame::Continuation(Item::Last(
|
||||
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
|
||||
))))
|
||||
|
|
|
@ -11,13 +11,24 @@ mod codec;
|
|||
mod frame;
|
||||
mod mask;
|
||||
mod proto;
|
||||
mod sink;
|
||||
mod stream;
|
||||
|
||||
pub use self::codec::{Codec, Frame, Item, Message};
|
||||
pub use self::frame::Parser;
|
||||
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
|
||||
pub use self::sink::WsSink;
|
||||
pub use self::stream::{StreamDecoder, StreamEncoder};
|
||||
|
||||
/// Websocket service errors
|
||||
#[derive(Debug, Display)]
|
||||
pub enum WsError<E> {
|
||||
Service(E),
|
||||
KeepAlive,
|
||||
Protocol(ProtocolError),
|
||||
Io(io::Error),
|
||||
}
|
||||
|
||||
/// Websocket protocol errors
|
||||
#[derive(Debug, Display, From)]
|
||||
pub enum ProtocolError {
|
||||
|
@ -48,7 +59,4 @@ pub enum ProtocolError {
|
|||
/// Unknown continuation fragment
|
||||
#[display(fmt = "Unknown continuation fragment.")]
|
||||
ContinuationFragment(OpCode),
|
||||
/// Io error
|
||||
#[display(fmt = "io error: {}", _0)]
|
||||
Io(io::Error),
|
||||
}
|
||||
|
|
29
ntex/src/ws/sink.rs
Normal file
29
ntex/src/ws/sink.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
use std::{future::Future, rc::Rc};
|
||||
|
||||
use crate::framed::State;
|
||||
use crate::ws;
|
||||
|
||||
pub struct WsSink(Rc<WsSinkInner>);
|
||||
|
||||
struct WsSinkInner {
|
||||
state: State,
|
||||
codec: ws::Codec,
|
||||
}
|
||||
|
||||
impl WsSink {
|
||||
pub(crate) fn new(state: State, codec: ws::Codec) -> Self {
|
||||
Self(Rc::new(WsSinkInner { state, codec }))
|
||||
}
|
||||
|
||||
pub fn send(
|
||||
&self,
|
||||
item: ws::Message,
|
||||
) -> impl Future<Output = Result<(), ws::ProtocolError>> {
|
||||
let inner = self.0.clone();
|
||||
|
||||
async move {
|
||||
inner.state.write_item(item, &inner.codec)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
|
@ -176,6 +176,7 @@ where
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytestring::ByteString;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
use super::*;
|
||||
|
@ -189,10 +190,10 @@ mod tests {
|
|||
let mut buf = BytesMut::new();
|
||||
let codec = Codec::new().client_mode();
|
||||
codec
|
||||
.encode(Message::Text("test1".to_string()), &mut buf)
|
||||
.encode(Message::Text(ByteString::from_static("test1")), &mut buf)
|
||||
.unwrap();
|
||||
codec
|
||||
.encode(Message::Text("test2".to_string()), &mut buf)
|
||||
.encode(Message::Text(ByteString::from_static("test2")), &mut buf)
|
||||
.unwrap();
|
||||
|
||||
tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap();
|
||||
|
@ -214,7 +215,7 @@ mod tests {
|
|||
let mut encoder = StreamEncoder::new(tx);
|
||||
|
||||
encoder
|
||||
.send(Ok::<_, ()>(Message::Text("test".to_string())))
|
||||
.send(Ok::<_, ()>(Message::Text(ByteString::from_static("test"))))
|
||||
.await
|
||||
.unwrap();
|
||||
encoder.flush().await.unwrap();
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::io;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::future::ok;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use bytestring::ByteString;
|
||||
use futures::{future::ok, SinkExt, StreamExt};
|
||||
|
||||
use ntex::framed::{DispatchItem, Dispatcher, State};
|
||||
use ntex::http::test::server as test_server;
|
||||
|
@ -17,9 +17,9 @@ async fn ws_service(
|
|||
let msg = match msg {
|
||||
DispatchItem::Item(msg) => match msg {
|
||||
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
|
||||
ws::Frame::Text(text) => {
|
||||
ws::Message::Text(String::from_utf8(Vec::from(text.as_ref())).unwrap())
|
||||
}
|
||||
ws::Frame::Text(text) => ws::Message::Text(
|
||||
String::from_utf8(Vec::from(text.as_ref())).unwrap().into(),
|
||||
),
|
||||
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
|
||||
ws::Frame::Close(reason) => ws::Message::Close(reason),
|
||||
_ => ws::Message::Close(None),
|
||||
|
@ -65,7 +65,7 @@ async fn test_simple() {
|
|||
// client service
|
||||
let mut framed = srv.ws().await.unwrap();
|
||||
framed
|
||||
.send(ws::Message::Text("text".to_string()))
|
||||
.send(ws::Message::Text(ByteString::from_static("text")))
|
||||
.await
|
||||
.unwrap();
|
||||
let item = framed.next().await.unwrap().unwrap();
|
||||
|
|
|
@ -3,6 +3,7 @@ use std::task::{Context, Poll};
|
|||
use std::{cell::Cell, io, marker::PhantomData, pin::Pin};
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytestring::ByteString;
|
||||
use futures::{future, Future, SinkExt, StreamExt};
|
||||
|
||||
use ntex::codec::{AsyncRead, AsyncWrite};
|
||||
|
@ -71,7 +72,7 @@ async fn service(
|
|||
DispatchItem::Item(msg) => match msg {
|
||||
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
|
||||
ws::Frame::Text(text) => {
|
||||
ws::Message::Text(String::from_utf8_lossy(&text).to_string())
|
||||
ws::Message::Text(String::from_utf8_lossy(&text).as_ref().into())
|
||||
}
|
||||
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
|
||||
ws::Frame::Continuation(item) => ws::Message::Continuation(item),
|
||||
|
@ -102,7 +103,7 @@ async fn test_simple() {
|
|||
// client service
|
||||
let mut framed = srv.ws().await.unwrap();
|
||||
framed
|
||||
.send(ws::Message::Text("text".to_string()))
|
||||
.send(ws::Message::Text(ByteString::from_static("text")))
|
||||
.await
|
||||
.unwrap();
|
||||
let (item, mut framed) = framed.into_future().await;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use std::io;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytestring::ByteString;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
use ntex::service::{fn_factory_with_config, fn_service};
|
||||
|
@ -10,7 +11,7 @@ async fn service(msg: ws::Frame) -> Result<Option<ws::Message>, io::Error> {
|
|||
let msg = match msg {
|
||||
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
|
||||
ws::Frame::Text(text) => {
|
||||
ws::Message::Text(String::from_utf8_lossy(&text).to_string())
|
||||
ws::Message::Text(String::from_utf8_lossy(&text).as_ref().into())
|
||||
}
|
||||
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
|
||||
ws::Frame::Close(reason) => ws::Message::Close(reason),
|
||||
|
@ -39,7 +40,7 @@ async fn web_ws() {
|
|||
// client service
|
||||
let mut framed = srv.ws().await.unwrap();
|
||||
framed
|
||||
.send(ws::Message::Text("text".to_string()))
|
||||
.send(ws::Message::Text(ByteString::from_static("text")))
|
||||
.await
|
||||
.unwrap();
|
||||
let item = framed.next().await.unwrap().unwrap();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue