Refactor write back-pressure (#39)

* refactor error handling

* refactor write back-pressure
This commit is contained in:
Nikolay Kim 2021-01-25 17:29:44 +06:00 committed by GitHub
parent f0fe2bbc59
commit 26543a4247
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 567 additions and 899 deletions

View file

@ -1,5 +1,11 @@
# Changes
## [0.2.0-b.7] - 2021-01-25
* Fix error handling for framed disaptcher
* Refactor framed disaptcher write back-pressure support
## [0.2.0-b.6] - 2021-01-24
* http: Pass io stream to upgrade handler

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.2.0-b.6"
version = "0.2.0-b.7"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"

View file

@ -1,12 +1,9 @@
//! Framed transport dispatcher
use std::task::{Context, Poll};
use std::{
cell::Cell, cell::RefCell, fmt, future::Future, pin::Pin, rc::Rc, time::Duration,
time::Instant,
};
use std::{cell::Cell, cell::RefCell, pin::Pin, rc::Rc, time::Duration, time::Instant};
use either::Either;
use futures::FutureExt;
use futures::{ready, Future, FutureExt};
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
use crate::framed::{DispatchItem, ReadTask, State, Timer, WriteTask};
@ -29,7 +26,7 @@ pin_project_lite::pin_project! {
service: S,
inner: DispatcherInner<S, U>,
#[pin]
response: Option<S::Future>,
fut: Option<S::Future>,
}
}
@ -38,11 +35,12 @@ where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Encoder + Decoder,
{
st: DispatcherState,
st: Cell<DispatcherState>,
state: State,
timer: Timer,
updated: Instant,
keepalive_timeout: u16,
ka_timeout: u16,
ka_updated: Cell<Instant>,
error: Cell<Option<S::Error>>,
shared: Rc<DispatcherShared<S, U>>,
}
@ -59,16 +57,24 @@ where
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
Processing,
WrEnabled,
WrWaitReady,
Stop,
Shutdown,
}
pub(crate) enum DispatcherError<S, U> {
enum DispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
}
enum PollService<U: Encoder + Decoder> {
Item(DispatchItem<U>),
ServiceError,
Ready,
}
impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
fn from(err: Either<S, U>) -> Self {
match err {
@ -78,19 +84,6 @@ impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
}
}
impl<E1, E2: fmt::Debug> DispatcherError<E1, E2> {
fn convert<U>(self) -> Option<DispatchItem<U>>
where
U: Encoder<Error = E2> + Decoder,
{
match self {
DispatcherError::KeepAlive => Some(DispatchItem::KeepAliveTimeout),
DispatcherError::Encoder(err) => Some(DispatchItem::EncoderError(err)),
DispatcherError::Service(_) => None,
}
}
}
impl<S, U> Dispatcher<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
@ -125,21 +118,22 @@ where
timer: Timer,
) -> Self {
let updated = timer.now();
let keepalive_timeout: u16 = 30;
let ka_timeout: u16 = 30;
// register keepalive timer
let expire = updated + Duration::from_secs(keepalive_timeout as u64);
let expire = updated + Duration::from_secs(ka_timeout as u64);
timer.register(expire, expire, &state);
Dispatcher {
service: service.into_service(),
response: None,
fut: None,
inner: DispatcherInner {
state,
timer,
updated,
keepalive_timeout,
st: DispatcherState::Processing,
ka_timeout,
ka_updated: Cell::new(updated),
error: Cell::new(None),
st: Cell::new(DispatcherState::Processing),
shared: Rc::new(DispatcherShared {
codec,
error: Cell::new(None),
@ -156,15 +150,15 @@ where
/// By default keep-alive timeout is set to 30 seconds.
pub fn keepalive_timeout(mut self, timeout: u16) -> Self {
// register keepalive timer
let prev = self.inner.updated
+ Duration::from_secs(self.inner.keepalive_timeout as u64);
let prev = self.inner.ka_updated.get() + self.inner.ka();
if timeout == 0 {
self.inner.timer.unregister(prev, &self.inner.state);
} else {
let expire = self.inner.updated + Duration::from_secs(timeout as u64);
let expire =
self.inner.ka_updated.get() + Duration::from_secs(timeout as u64);
self.inner.timer.register(expire, prev, &self.inner.state);
}
self.inner.keepalive_timeout = timeout;
self.inner.ka_timeout = timeout;
self
}
@ -191,71 +185,14 @@ where
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
{
fn handle_result(
&self,
item: Result<S::Response, S::Error>,
state: &State,
wake: bool,
) {
fn handle_result(&self, item: Result<S::Response, S::Error>, state: &State) {
self.inflight.set(self.inflight.get() - 1);
if let Err(err) = state.write_result(item, &self.codec) {
self.error.set(Some(err.into()));
}
if wake {
state.dsp_wake_task()
}
}
}
impl<S, U> DispatcherInner<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Decoder + Encoder,
{
fn take_error(&self) -> Option<DispatchItem<U>> {
// check for errors
self.shared
.error
.take()
.and_then(|err| err.convert())
.or_else(|| self.state.take_io_error().map(DispatchItem::IoError))
}
/// check keepalive timeout
fn check_keepalive(&self) {
if self.state.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
} else {
self.shared.error.set(Some(DispatcherError::KeepAlive));
}
self.state.dsp_mark_stopped();
}
}
/// update keep-alive timer
fn update_keepalive(&mut self) {
if self.keepalive_timeout != 0 {
let updated = self.timer.now();
if updated != self.updated {
let ka = Duration::from_secs(self.keepalive_timeout as u64);
self.timer
.register(updated + ka, self.updated + ka, &self.state);
self.updated = updated;
}
}
}
/// unregister keep-alive timer
fn unregister_keepalive(&self) {
if self.keepalive_timeout != 0 {
self.timer.unregister(
self.updated + Duration::from_secs(self.keepalive_timeout as u64),
&self.state,
);
match state.write_result(item, &self.codec) {
Ok(true) => (),
Ok(false) => state.enable_write_backpressure(),
Err(err) => self.error.set(Some(err.into())),
}
state.dsp_wake_task();
}
}
@ -269,181 +206,311 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
let slf = &this.inner;
let state = &slf.state;
// handle service response future
if let Some(fut) = this.response.as_mut().as_pin_mut() {
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
match fut.poll(cx) {
Poll::Pending => (),
Poll::Ready(item) => {
this.inner
.shared
.handle_result(item, &this.inner.state, false);
this.response.set(None);
this.fut.set(None);
slf.shared.inflight.set(slf.shared.inflight.get() - 1);
let _ = slf.handle_result(item, cx);
}
}
}
match this.inner.st {
DispatcherState::Processing => {
loop {
match this.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let mut retry = false;
loop {
match slf.st.get() {
DispatcherState::WrEnabled => {
let item = match ready!(slf.poll_service(&this.service, cx)) {
PollService::Ready => {
slf.st.set(DispatcherState::WrWaitReady);
DispatchItem::WBackPressureEnabled
}
PollService::Item(item) => item,
PollService::ServiceError => continue,
};
// service is ready, wake io read task
this.inner.state.dsp_restart_read_task();
// check keepalive timeout
this.inner.check_keepalive();
let item = if this.inner.state.is_dsp_stopped() {
log::trace!("dispatcher is instructed to stop");
// unregister keep-alive timer
this.inner.unregister_keepalive();
// check for errors
retry = true;
this.inner.st = DispatcherState::Stop;
this.inner.take_error()
} else {
// decode incoming bytes stream
if this.inner.state.is_read_ready() {
let item = this
.inner
.state
.decode_item(&this.inner.shared.codec);
match item {
Ok(Some(el)) => {
this.inner.update_keepalive();
Some(DispatchItem::Item(el))
}
Ok(None) => {
log::trace!("not enough data to decode next frame, register dispatch task");
this.inner
.state
.dsp_read_more_data(cx.waker());
return Poll::Pending;
}
Err(err) => {
retry = true;
this.inner.st = DispatcherState::Stop;
this.inner.unregister_keepalive();
Some(DispatchItem::DecoderError(err))
}
}
} else {
this.inner.state.dsp_register_task(cx.waker());
return Poll::Pending;
}
};
// call service
if let Some(item) = item {
// optimize first call
if this.response.is_none() {
this.response.set(Some(this.service.call(item)));
let res = this
.response
.as_mut()
.as_pin_mut()
.unwrap()
.poll(cx);
if let Poll::Ready(res) = res {
if let Err(err) = this
.inner
.state
.write_result(res, &this.inner.shared.codec)
{
this.inner
.shared
.error
.set(Some(err.into()));
}
this.response.set(None);
} else {
this.inner
.shared
.inflight
.set(this.inner.shared.inflight.get() + 1);
}
} else {
this.inner
.shared
.inflight
.set(this.inner.shared.inflight.get() + 1);
let st = this.inner.state.clone();
let shared = this.inner.shared.clone();
crate::rt::spawn(this.service.call(item).map(
move |item| {
shared.handle_result(item, &st, true);
},
));
}
// call service
if this.fut.is_none() {
// optimize first service call
this.fut.set(Some(this.service.call(item)));
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
Poll::Ready(res) => {
this.fut.set(None);
ready!(slf.handle_result(res, cx));
}
// run again
if retry {
return self.poll(cx);
Poll::Pending => {
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
}
}
Poll::Pending => {
// pause io read task
log::trace!("service is not ready, register dispatch task");
this.inner.state.dsp_service_not_ready(cx.waker());
return Poll::Pending;
}
Poll::Ready(Err(err)) => {
// handle service readiness error
log::trace!("service readiness check failed, stopping");
this.inner.st = DispatcherState::Stop;
this.inner.state.dsp_mark_stopped();
this.inner
.shared
.error
.set(Some(DispatcherError::Service(err)));
this.inner.unregister_keepalive();
return self.poll(cx);
}
} else {
slf.spawn_service_call(this.service.call(item));
}
}
}
// drain service responses
DispatcherState::Stop => {
// service may relay on poll_ready for response results
let _ = this.service.poll_ready(cx);
DispatcherState::WrWaitReady => {
let item = match ready!(slf.poll_service(&this.service, cx)) {
PollService::Ready => {
if state.is_write_backpressure_disabled() {
slf.st.set(DispatcherState::Processing);
DispatchItem::WBackPressureDisabled
} else {
return Poll::Pending;
}
}
PollService::Item(item) => item,
PollService::ServiceError => continue,
};
if this.inner.shared.inflight.get() == 0 {
this.inner.state.shutdown_io();
this.inner.st = DispatcherState::Shutdown;
self.poll(cx)
} else {
this.inner.state.dsp_register_task(cx.waker());
Poll::Pending
}
}
// shutdown service
DispatcherState::Shutdown => {
let err = this.inner.shared.error.take();
if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
log::trace!("service shutdown is completed, stop");
Poll::Ready(if let Some(DispatcherError::Service(err)) = err {
Err(err)
// call service
if this.fut.is_none() {
// optimize first service call
this.fut.set(Some(this.service.call(item)));
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
Poll::Ready(res) => {
this.fut.set(None);
ready!(slf.handle_result(res, cx));
}
Poll::Pending => {
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
}
}
} else {
Ok(())
})
} else {
this.inner.shared.error.set(err);
Poll::Pending
slf.spawn_service_call(this.service.call(item));
}
}
DispatcherState::Processing => {
let item = match ready!(slf.poll_service(&this.service, cx)) {
PollService::Ready => {
if state.is_write_backpressure_enabled() {
// instruct write task to notify dispatcher when data is flushed
state.dsp_enable_write_backpressure(cx.waker());
slf.st.set(DispatcherState::WrWaitReady);
DispatchItem::WBackPressureEnabled
} else if state.is_read_ready() {
// decode incoming bytes if buffer is ready
match state.decode_item(&slf.shared.codec) {
Ok(Some(el)) => {
slf.update_keepalive();
DispatchItem::Item(el)
}
Ok(None) => {
log::trace!("not enough data to decode next frame, register dispatch task");
state.dsp_read_more_data(cx.waker());
return Poll::Pending;
}
Err(err) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::DecoderError(err)
}
}
} else {
// no new events
state.dsp_register_task(cx.waker());
return Poll::Pending;
}
}
PollService::Item(item) => item,
PollService::ServiceError => continue,
};
// call service
if this.fut.is_none() {
// optimize first service call
this.fut.set(Some(this.service.call(item)));
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
Poll::Ready(res) => {
this.fut.set(None);
ready!(slf.handle_result(res, cx));
}
Poll::Pending => {
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
}
}
} else {
slf.spawn_service_call(this.service.call(item));
}
}
// drain service responses
DispatcherState::Stop => {
// service may relay on poll_ready for response results
let _ = this.service.poll_ready(cx);
if slf.shared.inflight.get() == 0 {
slf.st.set(DispatcherState::Shutdown);
state.shutdown_io();
} else {
state.dsp_register_task(cx.waker());
return Poll::Pending;
}
}
// shutdown service
DispatcherState::Shutdown => {
let err = slf.error.take();
return if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
log::trace!("service shutdown is completed, stop");
Poll::Ready(if let Some(err) = err {
Err(err)
} else {
Ok(())
})
} else {
slf.error.set(err);
Poll::Pending
};
}
}
}
}
}
impl<S, U> DispatcherInner<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + 'static,
{
/// spawn service call
fn spawn_service_call(&self, fut: S::Future) {
self.shared.inflight.set(self.shared.inflight.get() + 1);
let st = self.state.clone();
let shared = self.shared.clone();
crate::rt::spawn(fut.map(move |item| shared.handle_result(item, &st)));
}
fn handle_result(
&self,
item: Result<Option<<U as Encoder>::Item>, S::Error>,
cx: &mut Context<'_>,
) -> Poll<()> {
match self.state.write_result(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => {
// instruct write task to notify dispatcher when data is flushed
self.state.dsp_enable_write_backpressure(cx.waker());
self.st.set(DispatcherState::WrEnabled);
return Poll::Pending;
}
Err(Either::Left(err)) => {
self.error.set(Some(err));
}
Err(Either::Right(err)) => {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
}
Poll::Ready(())
}
fn poll_service(&self, srv: &S, cx: &mut Context<'_>) -> Poll<PollService<U>> {
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
// service is ready, wake io read task
self.state.dsp_restart_read_task();
// check keepalive timeout
self.check_keepalive();
// check for errors
Poll::Ready(if let Some(err) = self.shared.error.take() {
log::trace!("error occured, stopping dispatcher");
self.unregister_keepalive();
self.st.set(DispatcherState::Stop);
match err {
DispatcherError::KeepAlive => {
PollService::Item(DispatchItem::KeepAliveTimeout)
}
DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
DispatcherError::Service(err) => {
self.error.set(Some(err));
PollService::ServiceError
}
}
} else if self.state.is_dsp_stopped() {
log::trace!("dispatcher is instructed to stop");
self.unregister_keepalive();
self.st.set(DispatcherState::Stop);
// get io error
if let Some(err) = self.state.take_io_error() {
PollService::Item(DispatchItem::IoError(err))
} else {
PollService::ServiceError
}
} else {
PollService::Ready
})
}
// pause io read task
Poll::Pending => {
log::trace!("service is not ready, register dispatch task");
self.state.dsp_service_not_ready(cx.waker());
Poll::Pending
}
// handle service readiness error
Poll::Ready(Err(err)) => {
log::trace!("service readiness check failed, stopping");
self.st.set(DispatcherState::Stop);
self.error.set(Some(err));
self.unregister_keepalive();
Poll::Ready(PollService::ServiceError)
}
}
}
fn ka(&self) -> Duration {
Duration::from_secs(self.ka_timeout as u64)
}
fn ka_enabled(&self) -> bool {
self.ka_timeout > 0
}
/// check keepalive timeout
fn check_keepalive(&self) {
if self.state.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
} else {
self.shared.error.set(Some(DispatcherError::KeepAlive));
}
}
}
/// update keep-alive timer
fn update_keepalive(&self) {
if self.ka_enabled() {
let updated = self.timer.now();
if updated != self.ka_updated.get() {
let ka = self.ka();
self.timer.register(
updated + ka,
self.ka_updated.get() + ka,
&self.state,
);
self.ka_updated.set(updated);
}
}
}
/// unregister keep-alive timer
fn unregister_keepalive(&self) {
if self.ka_enabled() {
self.timer
.unregister(self.ka_updated.get() + self.ka(), &self.state);
}
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
@ -473,8 +540,8 @@ mod tests {
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
let timer = Timer::default();
let keepalive_timeout = 30;
let updated = timer.now();
let ka_timeout = 30;
let ka_updated = timer.now();
let state = State::new();
let io = Rc::new(RefCell::new(io));
let shared = Rc::new(DispatcherShared {
@ -489,14 +556,15 @@ mod tests {
(
Dispatcher {
service: service.into_service(),
response: None,
fut: None,
inner: DispatcherInner {
shared,
timer,
updated,
keepalive_timeout,
ka_timeout,
ka_updated: Cell::new(ka_updated),
state: state.clone(),
st: DispatcherState::Processing,
error: Cell::new(None),
st: Cell::new(DispatcherState::Processing),
},
},
state,

View file

@ -17,6 +17,10 @@ use crate::codec::{Decoder, Encoder};
/// Framed transport item
pub enum DispatchItem<U: Encoder + Decoder> {
Item(<U as Decoder>::Item),
/// Write back-pressure enabled
WBackPressureEnabled,
/// Write back-pressure disabled
WBackPressureDisabled,
/// Keep alive timeout
KeepAliveTimeout,
/// Decoder parse error
@ -37,6 +41,12 @@ where
DispatchItem::Item(ref item) => {
write!(fmt, "DispatchItem::Item({:?})", item)
}
DispatchItem::WBackPressureEnabled => {
write!(fmt, "DispatchItem::WBackPressureEnabled")
}
DispatchItem::WBackPressureDisabled => {
write!(fmt, "DispatchItem::WBackPressureDisabled")
}
DispatchItem::KeepAliveTimeout => {
write!(fmt, "DispatchItem::KeepAliveTimeout")
}
@ -52,3 +62,28 @@ where
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::BytesCodec;
#[test]
fn test_fmt() {
type T = DispatchItem<BytesCodec>;
let err = T::EncoderError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::Encoder"));
let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::Decoder"));
let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::IoError"));
assert!(format!("{:?}", T::WBackPressureEnabled)
.contains("DispatchItem::WBackPressureEnabled"));
assert!(format!("{:?}", T::WBackPressureDisabled)
.contains("DispatchItem::WBackPressureDisabled"));
assert!(format!("{:?}", T::KeepAliveTimeout)
.contains("DispatchItem::KeepAliveTimeout"));
}
}

View file

@ -10,29 +10,31 @@ use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}
use crate::framed::write::flush;
use crate::task::LocalWaker;
const HW: usize = 8 * 1024;
const HW: usize = 16 * 1024;
bitflags::bitflags! {
pub struct Flags: u16 {
const DSP_STOP = 0b0000_0001;
const DSP_KEEPALIVE = 0b0000_0010;
const DSP_STOP = 0b0000_0000_0001;
const DSP_KEEPALIVE = 0b0000_0000_0010;
/// io error occured
const IO_ERR = 0b0000_0100;
const IO_ERR = 0b0000_0000_0100;
/// stop io tasks
const IO_STOP = 0b0000_1000;
const IO_STOP = 0b0000_0000_1000;
/// shutdown io tasks
const IO_SHUTDOWN = 0b0001_0000;
const IO_SHUTDOWN = 0b0000_0001_0000;
/// pause io read
const RD_PAUSED = 0b0010_0000;
const RD_PAUSED = 0b0000_0010_0000;
/// new data is available
const RD_READY = 0b0100_0000;
const RD_READY = 0b0000_0100_0000;
/// write task is ready
const WR_READY = 0b0001_0000_0000;
/// write buffer is full
const WR_NOT_READY = 0b1000_0000;
const WR_NOT_READY = 0b0010_0000_0000;
const ST_DSP_ERR = 0b10000_0000;
const ST_DSP_ERR = 0b0001_0000_0000_0000;
}
}
@ -163,14 +165,43 @@ impl State {
self.0.flags.get().contains(Flags::RD_READY)
}
/// read task must be paused if service is not ready (RD_PAUSED)
pub(super) fn is_read_paused(&self) -> bool {
self.0.flags.get().contains(Flags::RD_PAUSED)
self.0.flags.get().intersects(Flags::RD_PAUSED)
}
#[inline]
/// Check if write buffer is ready
pub fn is_write_ready(&self) -> bool {
!self.0.flags.get().contains(Flags::WR_NOT_READY)
/// Check if write back-pressure is disabled
pub fn is_write_backpressure_disabled(&self) -> bool {
let mut flags = self.0.flags.get();
if flags.contains(Flags::WR_READY) {
flags.remove(Flags::WR_READY);
self.0.flags.set(flags);
true
} else {
false
}
}
#[inline]
/// Check if write back-pressure is enabled
pub fn is_write_backpressure_enabled(&self) -> bool {
let mut flags = self.0.flags.get();
if flags.contains(Flags::WR_READY) {
flags.remove(Flags::WR_READY);
self.0.flags.set(flags);
true
} else {
false
}
}
#[inline]
/// Enable write back-persurre
pub fn enable_write_backpressure(&self) {
let mut flags = self.0.flags.get();
flags.insert(Flags::WR_NOT_READY);
self.0.flags.set(flags);
}
#[inline]
@ -272,6 +303,7 @@ impl State {
let mut flags = self.0.flags.get();
if flags.contains(Flags::WR_NOT_READY) {
flags.remove(Flags::WR_NOT_READY);
flags.insert(Flags::WR_READY);
self.0.flags.set(flags);
self.0.dispatch_task.wake();
}
@ -298,16 +330,24 @@ impl State {
self.0.dispatch_task.register(waker);
}
#[inline]
/// Check if write buff is full
pub fn is_write_buf_full(&self) -> bool {
self.0.write_buf.borrow().len() >= HW
}
#[inline]
/// Wait until write task flushes data to socket
pub fn dsp_flush_write_data(&self, waker: &Waker) {
///
/// Write task must be waken up separately.
pub fn dsp_enable_write_backpressure(&self, waker: &Waker) {
let mut flags = self.0.flags.get();
flags.insert(Flags::WR_NOT_READY);
self.0.flags.set(flags);
self.0.write_task.wake();
self.0.dispatch_task.register(waker);
}
#[doc(hidden)]
#[inline]
/// Mark dispatcher as stopped
pub fn dsp_mark_stopped(&self) {

View file

@ -7,7 +7,7 @@ use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::State;
use crate::rt::time::{delay_for, Delay};
const HW: usize = 8 * 1024;
const HW: usize = 16 * 1024;
#[derive(Debug)]
enum IoWriteState {

View file

@ -1,20 +1,17 @@
//! Websockets client
use std::convert::TryFrom;
use std::net::SocketAddr;
use std::rc::Rc;
use std::{fmt, str};
use std::{convert::TryFrom, fmt, net::SocketAddr, rc::Rc, str};
#[cfg(feature = "cookie")]
use coo_kie::{Cookie, CookieJar};
use futures::Stream;
use futures::future::{err, ok, Either};
use crate::codec::{AsyncRead, AsyncWrite, Framed};
use crate::framed::{DispatchItem, Dispatcher, State};
use crate::http::error::HttpError;
use crate::http::header::{self, HeaderName, HeaderValue, AUTHORIZATION};
use crate::http::{ConnectionType, Payload, RequestHead, StatusCode, Uri};
use crate::rt::time::timeout;
use crate::service::{IntoService, Service};
use crate::util::framed::{Dispatcher, DispatcherError};
use crate::service::{apply_fn, IntoService, Service};
use crate::ws;
pub use crate::ws::{CloseCode, CloseReason, Frame, Message};
@ -221,9 +218,7 @@ impl WsRequest {
}
/// Complete request construction and connect to a websockets server.
pub async fn connect(
mut self,
) -> Result<(ClientResponse, Framed<BoxedSocket, ws::Codec>), WsClientError> {
pub async fn connect(mut self) -> Result<WsConnection, WsClientError> {
if let Some(e) = self.err.take() {
return Err(WsClientError::from(e));
}
@ -378,8 +373,8 @@ impl WsRequest {
return Err(WsClientError::MissingWebSocketAcceptHeader);
};
// response and ws framed
Ok((
// response and ws io
Ok(WsConnection::new(
ClientResponse::new(head, Payload::None),
framed.map_codec(|_| {
if server_mode {
@ -407,21 +402,70 @@ impl fmt::Debug for WsRequest {
}
}
/// Start client websockets service.
pub async fn start<Io, T, F, Rx>(
framed: Framed<Io, ws::Codec>,
rx: Rx,
service: F,
) -> Result<(), DispatcherError<T::Error, ws::Codec>>
pub struct WsConnection<Io = BoxedSocket> {
io: Io,
state: State,
codec: ws::Codec,
res: ClientResponse,
}
impl<Io> WsConnection<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = ws::Frame, Response = Option<ws::Message>>,
T::Error: 'static,
T::Future: 'static,
F: IntoService<T>,
Rx: Stream<Item = ws::Message> + Unpin + 'static,
{
Dispatcher::with(framed, Some(rx), service.into_service()).await
fn new(res: ClientResponse, framed: Framed<Io, ws::Codec>) -> Self {
let (io, codec, state) = State::from_framed(framed);
Self {
io,
codec,
state,
res,
}
}
/// Get ws sink
pub fn sink(&self) -> ws::WsSink {
ws::WsSink::new(self.state.clone(), self.codec.clone())
}
/// Get reference to response
pub fn response(&self) -> &ClientResponse {
&self.res
}
/// Start client websockets service.
pub async fn start<T, F, Rx>(self, service: F) -> Result<(), ws::WsError<T::Error>>
where
T: Service<Request = ws::Frame, Response = Option<ws::Message>> + 'static,
F: IntoService<T>,
{
let service = apply_fn(
service.into_service().map_err(ws::WsError::Service),
|req, srv| match req {
DispatchItem::Item(item) => Either::Left(srv.call(item)),
DispatchItem::WBackPressureEnabled
| DispatchItem::WBackPressureDisabled => Either::Right(ok(None)),
DispatchItem::KeepAliveTimeout => {
Either::Right(err(ws::WsError::KeepAlive))
}
DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => {
Either::Right(err(ws::WsError::Protocol(e)))
}
DispatchItem::IoError(e) => Either::Right(err(ws::WsError::Io(e))),
},
);
Dispatcher::new(self.io, self.codec, self.state, service, Default::default())
.await
}
/// Consumes the `WsConnection`, returning it'as underlying I/O framed object
/// and response.
pub fn into_inner(self) -> (ClientResponse, Framed<Io, ws::Codec>) {
let framed = self.state.into_framed(self.io, self.codec);
(self.res, framed)
}
}
#[cfg(test)]

View file

@ -9,11 +9,8 @@ pub use actix_threadpool::BlockingError;
pub use futures::channel::oneshot::Canceled;
pub use http::Error as HttpError;
use crate::codec::{Decoder, Encoder};
use crate::util::framed::DispatcherError;
use super::body::Body;
use super::response::Response;
use crate::http::body::Body;
use crate::http::response::Response;
/// Error that can be converted to `Response`
pub trait ResponseError: fmt::Display + fmt::Debug {
@ -60,14 +57,6 @@ impl ResponseError for io::Error {}
/// `InternalServerError` for `JsonError`
impl ResponseError for serde_json::error::Error {}
impl<E, U: Encoder + Decoder + 'static> ResponseError for DispatcherError<E, U>
where
E: fmt::Debug + fmt::Display + 'static,
<U as Encoder>::Error: fmt::Debug,
<U as Decoder>::Error: fmt::Debug,
{
}
/// A set of errors that can occur during parsing HTTP streams
#[derive(Debug, Display, From)]
pub enum ParseError {

View file

@ -402,7 +402,9 @@ where
*this.st = st;
}
WritePayloadStatus::Pause => {
this.inner.state.dsp_flush_write_data(cx.waker());
this.inner
.state
.dsp_enable_write_backpressure(cx.waker());
return Poll::Pending;
}
WritePayloadStatus::Continue => (),

View file

@ -328,7 +328,7 @@ impl TestServer {
{
let url = self.url(path);
let connect = self.client.ws(url).connect();
connect.await.map(|(_, framed)| framed)
connect.await.map(|ws| ws.into_inner().1)
}
/// Connect to a websocket server

View file

@ -1,554 +0,0 @@
//! Framed transport dispatcher
use std::{fmt, io, pin::Pin, task::Context, task::Poll, time::Duration};
use either::Either;
use futures::{ready, Future, FutureExt, Stream};
use log::debug;
use crate::channel::mpsc;
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
use crate::rt::time::{delay_for, Delay};
use crate::service::{IntoService, Service};
type Request<U> = <U as Decoder>::Item;
type Response<U> = <U as Encoder>::Item;
/// Framed transport errors
pub enum DispatcherError<E, U: Encoder + Decoder> {
/// Inner service error
Service(E),
/// Encoder parse error
Encoder(<U as Encoder>::Error),
/// Decoder parse error
Decoder(<U as Decoder>::Error),
/// Unexpected io error
IoError(io::Error),
}
impl<E, U: Encoder + Decoder> From<E> for DispatcherError<E, U> {
fn from(err: E) -> Self {
DispatcherError::Service(err)
}
}
impl<E, U: Encoder + Decoder> From<Either<E, io::Error>> for DispatcherError<E, U> {
fn from(err: Either<E, io::Error>) -> Self {
match err {
Either::Left(err) => DispatcherError::Service(err),
Either::Right(err) => DispatcherError::IoError(err),
}
}
}
impl<E, U: Encoder + Decoder> fmt::Debug for DispatcherError<E, U>
where
E: fmt::Debug,
<U as Encoder>::Error: fmt::Debug,
<U as Decoder>::Error: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
DispatcherError::Service(ref e) => {
write!(fmt, "DispatcherError::Service({:?})", e)
}
DispatcherError::Encoder(ref e) => {
write!(fmt, "DispatcherError::Encoder({:?})", e)
}
DispatcherError::Decoder(ref e) => {
write!(fmt, "DispatcherError::Decoder({:?})", e)
}
DispatcherError::IoError(ref e) => {
write!(fmt, "DispatcherError::IoError({:?})", e)
}
}
}
}
impl<E, U: Encoder + Decoder> fmt::Display for DispatcherError<E, U>
where
E: fmt::Display,
<U as Encoder>::Error: fmt::Debug,
<U as Decoder>::Error: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
DispatcherError::Service(ref e) => write!(fmt, "{}", e),
DispatcherError::Encoder(ref e) => write!(fmt, "{:?}", e),
DispatcherError::Decoder(ref e) => write!(fmt, "{:?}", e),
DispatcherError::IoError(ref e) => write!(fmt, "{}", e),
}
}
}
pin_project_lite::pin_project! {
/// FramedTransport - is a future that reads frames from Framed object
/// and pass then to the service.
pub struct Dispatcher<S, T, U, Out>
where
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
U: Encoder,
U: Decoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
Out: Stream<Item = <U as Encoder>::Item>,
Out: Unpin,
{
inner: InnerDispatcher<S, T, U, Out>,
}
}
impl<S, T, U> Dispatcher<S, T, U, mpsc::Receiver<<U as Encoder>::Item>>
where
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite + Unpin,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
{
/// Construct new `Dispatcher` instance
pub fn new<F: IntoService<S>>(framed: Framed<T, U>, service: F) -> Self {
Dispatcher {
inner: InnerDispatcher {
framed,
sink: None,
rx: mpsc::channel().1,
service: service.into_service(),
state: FramedState::Processing,
disconnect_timeout: 1000,
},
}
}
}
impl<S, T, U, In> Dispatcher<S, T, U, In>
where
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite + Unpin,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
In: Stream<Item = <U as Encoder>::Item> + Unpin,
{
/// Construct new `Dispatcher` instance with outgoing messages stream.
pub fn with<F: IntoService<S>>(
framed: Framed<T, U>,
sink: Option<In>,
service: F,
) -> Self {
Dispatcher {
inner: InnerDispatcher {
framed,
sink,
rx: mpsc::channel().1,
service: service.into_service(),
state: FramedState::Processing,
disconnect_timeout: 1000,
},
}
}
/// Set connection disconnect timeout in milliseconds.
///
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
/// within this time, the connection get dropped.
///
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 1 seconds.
pub fn disconnect_timeout(mut self, val: u64) -> Self {
self.inner.disconnect_timeout = val;
self
}
}
impl<S, T, U, In> Future for Dispatcher<S, T, U, In>
where
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite + Unpin,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
In: Stream<Item = <U as Encoder>::Item> + Unpin,
{
type Output = Result<(), DispatcherError<S::Error, U>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
enum FramedState<S: Service, U: Encoder + Decoder> {
Processing,
FlushAndStop(Option<DispatcherError<S::Error, U>>),
Shutdown(Option<DispatcherError<S::Error, U>>),
ShutdownIo(Delay, Option<Result<(), DispatcherError<S::Error, U>>>),
}
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum PollResult {
Continue,
Pending,
}
struct InnerDispatcher<S, T, U, Out>
where
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
Out: Stream<Item = <U as Encoder>::Item> + Unpin,
{
service: S,
sink: Option<Out>,
state: FramedState<S, U>,
framed: Framed<T, U>,
rx: mpsc::Receiver<Result<<U as Encoder>::Item, S::Error>>,
disconnect_timeout: u64,
}
impl<S, T, U, Out> InnerDispatcher<S, T, U, Out>
where
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
T: AsyncRead + AsyncWrite + Unpin,
U: Decoder + Encoder,
<U as Encoder>::Item: 'static,
<U as Encoder>::Error: std::fmt::Debug,
Out: Stream<Item = <U as Encoder>::Item> + Unpin,
{
fn poll_read(&mut self, cx: &mut Context<'_>) -> PollResult {
loop {
match self.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let item = match self.framed.next_item(cx) {
Poll::Ready(Some(Ok(el))) => el,
Poll::Ready(Some(Err(err))) => {
log::trace!("Framed decode error");
self.state = match err {
Either::Left(err) => FramedState::Shutdown(Some(
DispatcherError::Decoder(err),
)),
Either::Right(err) => FramedState::Shutdown(Some(
DispatcherError::IoError(err),
)),
};
return PollResult::Continue;
}
Poll::Pending => return PollResult::Pending,
Poll::Ready(None) => {
log::trace!("Client disconnected");
self.state = FramedState::Shutdown(None);
return PollResult::Continue;
}
};
let tx = self.rx.sender();
crate::rt::spawn(self.service.call(item).map(move |item| {
let item = match item {
Ok(Some(item)) => Ok(item),
Err(err) => Err(err),
_ => return,
};
let _ = tx.send(item);
}));
}
Poll::Pending => return PollResult::Pending,
Poll::Ready(Err(err)) => {
self.state =
FramedState::FlushAndStop(Some(DispatcherError::Service(err)));
return PollResult::Continue;
}
}
}
}
/// write to framed object
fn poll_write(&mut self, cx: &mut Context<'_>) -> PollResult {
loop {
while !self.framed.is_write_buf_full() {
match Pin::new(&mut self.rx).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => {
if let Err(err) = self.framed.write(msg) {
log::trace!("Framed write error: {:?}", err);
self.state = FramedState::Shutdown(Some(
DispatcherError::Encoder(err),
));
return PollResult::Continue;
}
continue;
}
Poll::Ready(Some(Err(err))) => {
self.state = FramedState::FlushAndStop(Some(
DispatcherError::Service(err),
));
return PollResult::Continue;
}
Poll::Ready(None) | Poll::Pending => {}
}
if let Some(ref mut sink) = self.sink {
match Pin::new(sink).poll_next(cx) {
Poll::Ready(Some(msg)) => {
if let Err(err) = self.framed.write(msg) {
log::trace!("Framed write error from sink: {:?}", err);
self.state = FramedState::Shutdown(Some(
DispatcherError::Encoder(err),
));
return PollResult::Continue;
}
continue;
}
Poll::Ready(None) => {
let _ = self.sink.take();
self.state = FramedState::FlushAndStop(None);
return PollResult::Continue;
}
Poll::Pending => (),
}
}
break;
}
if !self.framed.is_write_buf_empty() {
match self.framed.flush(cx) {
Poll::Pending => break,
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => {
debug!("Error sending data: {:?}", err);
self.state =
FramedState::Shutdown(Some(DispatcherError::IoError(err)));
return PollResult::Continue;
}
}
} else {
break;
}
}
PollResult::Pending
}
pub(super) fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<(), DispatcherError<S::Error, U>>> {
loop {
match self.state {
FramedState::Processing => {
let read = self.poll_read(cx);
let write = self.poll_write(cx);
if read == PollResult::Continue || write == PollResult::Continue {
continue;
} else {
return Poll::Pending;
}
}
FramedState::FlushAndStop(ref mut err) => {
// drain service responses
match Pin::new(&mut self.rx).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => {
if let Err(err) = self.framed.write(msg) {
log::trace!("Framed write message error: {:?}", err);
self.state = FramedState::Shutdown(Some(
DispatcherError::Encoder(err),
));
continue;
}
}
Poll::Ready(Some(Err(err))) => {
log::trace!("Sink poll error");
self.state = FramedState::Shutdown(Some(err.into()));
continue;
}
Poll::Ready(None) | Poll::Pending => (),
}
// flush io
if !self.framed.is_write_buf_empty() {
match self.framed.flush(cx) {
Poll::Ready(Err(err)) => {
debug!("Error sending data: {:?}", err);
}
Poll::Pending => return Poll::Pending,
Poll::Ready(_) => (),
}
};
log::trace!("Framed flushed, shutdown");
self.state = FramedState::Shutdown(err.take());
}
FramedState::Shutdown(ref mut err) => {
return if self.service.poll_shutdown(cx, err.is_some()).is_ready() {
let result = if let Some(err) = err.take() {
if let DispatcherError::Service(_) = err {
Err(err)
} else {
// no need for io shutdown because io error occured
return Poll::Ready(Err(err));
}
} else {
Ok(())
};
// frame close, closes io WR side and waits for disconnect
// on read side. we need disconnect timeout, because it
// could hang forever.
let pending = self.framed.close(cx).is_pending();
if self.disconnect_timeout != 0 && pending {
self.state = FramedState::ShutdownIo(
delay_for(Duration::from_millis(
self.disconnect_timeout,
)),
Some(result),
);
continue;
} else {
Poll::Ready(result)
}
} else {
Poll::Pending
};
}
FramedState::ShutdownIo(ref mut delay, ref mut err) => {
if let Poll::Ready(res) = self.framed.close(cx) {
return match err.take() {
Some(Ok(_)) | None => {
Poll::Ready(res.map_err(DispatcherError::IoError))
}
Some(Err(e)) => Poll::Ready(Err(e)),
};
} else {
ready!(Pin::new(delay).poll(cx));
return Poll::Ready(Ok(()));
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use bytes::{Bytes, BytesMut};
use derive_more::Display;
use futures::future::ok;
use std::io;
use super::*;
use crate::channel::mpsc;
use crate::codec::{BytesCodec, Framed};
use crate::rt::time::delay_for;
use crate::testing::Io;
#[test]
fn test_err() {
#[derive(Debug, Display)]
struct TestError;
type T = DispatcherError<TestError, BytesCodec>;
let err = T::Encoder(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatcherError::Encoder"));
assert!(format!("{}", err).contains("Custom"));
let err = T::Decoder(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatcherError::Decoder"));
assert!(format!("{}", err).contains("Custom"));
let err = T::from(TestError);
assert!(format!("{:?}", err).contains("DispatcherError::Service"));
assert_eq!(format!("{}", err), "TestError");
}
#[ntex_rt::test]
async fn test_basic() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let framed = Framed::new(server, BytesCodec);
let disp = Dispatcher::new(
framed,
crate::fn_service(|msg: BytesMut| async move {
delay_for(Duration::from_millis(50)).await;
Ok::<_, ()>(Some(msg.freeze()))
}),
);
crate::rt::spawn(disp.map(|_| ()));
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex_rt::test]
async fn test_sink() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (tx, rx) = mpsc::channel();
let framed = Framed::new(server, BytesCodec);
let disp = Dispatcher::with(
framed,
Some(rx),
crate::fn_service(|msg: BytesMut| ok::<_, ()>(Some(msg.freeze()))),
)
.disconnect_timeout(25);
crate::rt::spawn(disp.map(|_| ()));
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(tx.send(Bytes::from_static(b"test")).is_ok());
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
drop(tx);
delay_for(Duration::from_millis(200)).await;
assert!(client.is_server_dropped());
}
#[ntex_rt::test]
async fn test_err_in_service() {
let (client, server) = Io::create();
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let mut framed = Framed::new(server, BytesCodec);
framed.write_buf().extend(b"GET /test HTTP/1\r\n\r\n");
let disp = Dispatcher::new(
framed,
crate::fn_service(|_: BytesMut| async { Err::<Option<Bytes>, _>(()) }),
);
crate::rt::spawn(disp.map(|_| ()));
let buf = client.read_any();
assert_eq!(buf, Bytes::from_static(b""));
delay_for(Duration::from_millis(25)).await;
// buffer should be flushed
client.remote_buffer_cap(1024);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
// write side must be closed, dispatcher waiting for read side to close
assert!(client.is_closed());
// close read side
client.close().await;
assert!(client.is_server_dropped());
}
}

View file

@ -2,7 +2,6 @@ pub mod buffer;
pub mod counter;
pub mod either;
mod extensions;
pub mod framed;
pub mod inflight;
pub mod keepalive;
pub mod stream;

View file

@ -156,11 +156,9 @@ where
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use futures::future::ok;
use futures::StreamExt;
use std::cell::Cell;
use std::rc::Rc;
use std::time::Duration;
use bytestring::ByteString;
use futures::{future::ok, StreamExt};
use std::{cell::Cell, rc::Rc, time::Duration};
use super::*;
use crate::channel::mpsc;
@ -183,7 +181,7 @@ mod tests {
encoder,
crate::fn_service(move |_| {
counter2.set(counter2.get() + 1);
ok(Some(ws::Message::Text("test".to_string())))
ok(Some(ws::Message::Text(ByteString::from_static("test"))))
}),
);
crate::rt::spawn(disp.map(|_| ()));
@ -191,7 +189,7 @@ mod tests {
let mut buf = BytesMut::new();
let codec = ws::Codec::new().client_mode();
codec
.encode(ws::Message::Text("test".to_string()), &mut buf)
.encode(ws::Message::Text(ByteString::from_static("test")), &mut buf)
.unwrap();
tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap();

View file

@ -940,7 +940,7 @@ impl TestServer {
{
let url = self.url(path);
let connect = self.client.ws(url).connect();
connect.await.map(|(_, framed)| framed)
connect.await.map(|ws| ws.into_inner().1)
}
/// Connect to a websocket server

View file

@ -1,4 +1,5 @@
use bytes::{Bytes, BytesMut};
use bytestring::ByteString;
use std::cell::Cell;
use crate::codec::{Decoder, Encoder};
@ -11,7 +12,7 @@ use super::ProtocolError;
#[derive(Debug, PartialEq)]
pub enum Message {
/// Text message
Text(String),
Text(ByteString),
/// Binary message
Binary(Bytes),
/// Continuation
@ -60,7 +61,7 @@ pub struct Codec {
bitflags::bitflags! {
struct Flags: u8 {
const SERVER = 0b0000_0001;
const CONTINUATION = 0b0000_0010;
const R_CONTINUATION = 0b0000_0010;
const W_CONTINUATION = 0b0000_0100;
}
}
@ -222,7 +223,7 @@ impl Decoder for Codec {
if !finished {
return match opcode {
OpCode::Continue => {
if self.flags.get().contains(Flags::CONTINUATION) {
if self.flags.get().contains(Flags::R_CONTINUATION) {
Ok(Some(Frame::Continuation(Item::Continue(
payload
.map(|pl| pl.freeze())
@ -233,8 +234,8 @@ impl Decoder for Codec {
}
}
OpCode::Binary => {
if !self.flags.get().contains(Flags::CONTINUATION) {
self.insert_flags(Flags::CONTINUATION);
if !self.flags.get().contains(Flags::R_CONTINUATION) {
self.insert_flags(Flags::R_CONTINUATION);
Ok(Some(Frame::Continuation(Item::FirstBinary(
payload
.map(|pl| pl.freeze())
@ -245,8 +246,8 @@ impl Decoder for Codec {
}
}
OpCode::Text => {
if !self.flags.get().contains(Flags::CONTINUATION) {
self.insert_flags(Flags::CONTINUATION);
if !self.flags.get().contains(Flags::R_CONTINUATION) {
self.insert_flags(Flags::R_CONTINUATION);
Ok(Some(Frame::Continuation(Item::FirstText(
payload
.map(|pl| pl.freeze())
@ -265,8 +266,8 @@ impl Decoder for Codec {
match opcode {
OpCode::Continue => {
if self.flags.get().contains(Flags::CONTINUATION) {
self.remove_flags(Flags::CONTINUATION);
if self.flags.get().contains(Flags::R_CONTINUATION) {
self.remove_flags(Flags::R_CONTINUATION);
Ok(Some(Frame::Continuation(Item::Last(
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
))))

View file

@ -11,13 +11,24 @@ mod codec;
mod frame;
mod mask;
mod proto;
mod sink;
mod stream;
pub use self::codec::{Codec, Frame, Item, Message};
pub use self::frame::Parser;
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
pub use self::sink::WsSink;
pub use self::stream::{StreamDecoder, StreamEncoder};
/// Websocket service errors
#[derive(Debug, Display)]
pub enum WsError<E> {
Service(E),
KeepAlive,
Protocol(ProtocolError),
Io(io::Error),
}
/// Websocket protocol errors
#[derive(Debug, Display, From)]
pub enum ProtocolError {
@ -48,7 +59,4 @@ pub enum ProtocolError {
/// Unknown continuation fragment
#[display(fmt = "Unknown continuation fragment.")]
ContinuationFragment(OpCode),
/// Io error
#[display(fmt = "io error: {}", _0)]
Io(io::Error),
}

29
ntex/src/ws/sink.rs Normal file
View file

@ -0,0 +1,29 @@
use std::{future::Future, rc::Rc};
use crate::framed::State;
use crate::ws;
pub struct WsSink(Rc<WsSinkInner>);
struct WsSinkInner {
state: State,
codec: ws::Codec,
}
impl WsSink {
pub(crate) fn new(state: State, codec: ws::Codec) -> Self {
Self(Rc::new(WsSinkInner { state, codec }))
}
pub fn send(
&self,
item: ws::Message,
) -> impl Future<Output = Result<(), ws::ProtocolError>> {
let inner = self.0.clone();
async move {
inner.state.write_item(item, &inner.codec)?;
Ok(())
}
}
}

View file

@ -176,6 +176,7 @@ where
#[cfg(test)]
mod tests {
use bytestring::ByteString;
use futures::{SinkExt, StreamExt};
use super::*;
@ -189,10 +190,10 @@ mod tests {
let mut buf = BytesMut::new();
let codec = Codec::new().client_mode();
codec
.encode(Message::Text("test1".to_string()), &mut buf)
.encode(Message::Text(ByteString::from_static("test1")), &mut buf)
.unwrap();
codec
.encode(Message::Text("test2".to_string()), &mut buf)
.encode(Message::Text(ByteString::from_static("test2")), &mut buf)
.unwrap();
tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap();
@ -214,7 +215,7 @@ mod tests {
let mut encoder = StreamEncoder::new(tx);
encoder
.send(Ok::<_, ()>(Message::Text("test".to_string())))
.send(Ok::<_, ()>(Message::Text(ByteString::from_static("test"))))
.await
.unwrap();
encoder.flush().await.unwrap();

View file

@ -1,8 +1,8 @@
use std::io;
use bytes::Bytes;
use futures::future::ok;
use futures::{SinkExt, StreamExt};
use bytestring::ByteString;
use futures::{future::ok, SinkExt, StreamExt};
use ntex::framed::{DispatchItem, Dispatcher, State};
use ntex::http::test::server as test_server;
@ -17,9 +17,9 @@ async fn ws_service(
let msg = match msg {
DispatchItem::Item(msg) => match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8(Vec::from(text.as_ref())).unwrap())
}
ws::Frame::Text(text) => ws::Message::Text(
String::from_utf8(Vec::from(text.as_ref())).unwrap().into(),
),
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => ws::Message::Close(None),
@ -65,7 +65,7 @@ async fn test_simple() {
// client service
let mut framed = srv.ws().await.unwrap();
framed
.send(ws::Message::Text("text".to_string()))
.send(ws::Message::Text(ByteString::from_static("text")))
.await
.unwrap();
let item = framed.next().await.unwrap().unwrap();

View file

@ -3,6 +3,7 @@ use std::task::{Context, Poll};
use std::{cell::Cell, io, marker::PhantomData, pin::Pin};
use bytes::Bytes;
use bytestring::ByteString;
use futures::{future, Future, SinkExt, StreamExt};
use ntex::codec::{AsyncRead, AsyncWrite};
@ -71,7 +72,7 @@ async fn service(
DispatchItem::Item(msg) => match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text).to_string())
ws::Message::Text(String::from_utf8_lossy(&text).as_ref().into())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
ws::Frame::Continuation(item) => ws::Message::Continuation(item),
@ -102,7 +103,7 @@ async fn test_simple() {
// client service
let mut framed = srv.ws().await.unwrap();
framed
.send(ws::Message::Text("text".to_string()))
.send(ws::Message::Text(ByteString::from_static("text")))
.await
.unwrap();
let (item, mut framed) = framed.into_future().await;

View file

@ -1,6 +1,7 @@
use std::io;
use bytes::Bytes;
use bytestring::ByteString;
use futures::{SinkExt, StreamExt};
use ntex::service::{fn_factory_with_config, fn_service};
@ -10,7 +11,7 @@ async fn service(msg: ws::Frame) -> Result<Option<ws::Message>, io::Error> {
let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text).to_string())
ws::Message::Text(String::from_utf8_lossy(&text).as_ref().into())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
ws::Frame::Close(reason) => ws::Message::Close(reason),
@ -39,7 +40,7 @@ async fn web_ws() {
// client service
let mut framed = srv.ws().await.unwrap();
framed
.send(ws::Message::Text("text".to_string()))
.send(ws::Message::Text(ByteString::from_static("text")))
.await
.unwrap();
let item = framed.next().await.unwrap().unwrap();