Refactor web websockets support (#97)

* Refactor ws handling
This commit is contained in:
Nikolay Kim 2022-01-23 19:56:56 +06:00 committed by GitHub
parent e5efdab4ed
commit 5d9a653f70
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 356 additions and 727 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.1.5] - 2022-01-23
* Add Eq,PartialEq,Hash,Debug impls to Io asn IoRef
## [0.1.4] - 2022-01-17
* Add Io::take() method

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.1.4"
version = "0.1.5"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -16,7 +16,7 @@ name = "ntex_io"
path = "src/lib.rs"
[dependencies]
ntex-codec = "0.6.0"
ntex-codec = "0.6.1"
ntex-bytes = "0.1.9"
ntex-util = "0.1.9"
ntex-service = "0.3.1"

View file

@ -754,7 +754,6 @@ mod tests {
#[ntex::test]
async fn test_keepalive() {
env_logger::init();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");

View file

@ -699,6 +699,39 @@ impl<F> AsRef<IoRef> for Io<F> {
}
}
impl<F> Eq for Io<F> {}
impl<F> PartialEq for Io<F> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}
impl<F> hash::Hash for Io<F> {
#[inline]
fn hash<H: hash::Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
impl<F> fmt::Debug for Io<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Io")
.field("open", &!self.is_closed())
.finish()
}
}
impl<F> Deref for Io<F> {
type Target = IoRef;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<F> Drop for Io<F> {
fn drop(&mut self) {
self.remove_keepalive_timer();
@ -727,23 +760,6 @@ impl<F> Drop for Io<F> {
}
}
impl fmt::Debug for Io {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Io")
.field("open", &!self.is_closed())
.finish()
}
}
impl<F> Deref for Io<F> {
type Target = IoRef;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// OnDisconnect future resolves when socket get disconnected
#[must_use = "OnDisconnect do nothing unless polled"]
pub struct OnDisconnect {

View file

@ -1,4 +1,4 @@
use std::{any, fmt, io};
use std::{any, fmt, hash, io};
use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_codec::{Decoder, Encoder};
@ -190,6 +190,22 @@ impl IoRef {
}
}
impl Eq for IoRef {}
impl PartialEq for IoRef {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}
impl hash::Hash for IoRef {
#[inline]
fn hash<H: hash::Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
impl fmt::Debug for IoRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IoRef")

View file

@ -5,6 +5,7 @@ use crate::{Filter, Io};
/// Sealed filter type
pub struct Sealed(pub(crate) Box<dyn Filter>);
#[derive(Debug)]
/// Boxed `Io` object with erased filter type
pub struct IoBoxed(Io<Sealed>);

View file

@ -16,5 +16,6 @@ syn = { version = "^1", features = ["full", "parsing"] }
proc-macro2 = "^1"
[dev-dependencies]
ntex = "0.5.0-b.0"
ntex = { version = "0.5.0", features = ["tokio"] }
futures = "0.3"
env_logger = "0.9"

View file

@ -115,10 +115,6 @@ async fn test_body() {
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(Method::CONNECT, srv.url("/test"));
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(Method::OPTIONS, srv.url("/test"));
let response = request.send().await.unwrap();
assert!(response.status().is_success());

View file

@ -691,12 +691,12 @@ mod tests {
#[derive(Debug, Deserialize)]
struct S {
inner: (String,),
_inner: (String,),
}
let s: Result<S, de::value::Error> =
de::Deserialize::deserialize(PathDeserializer::new(&path));
assert!(s.is_err());
assert!(format!("{:?}", s).contains("missing field `inner`"));
assert!(format!("{:?}", s).contains("missing field `_inner`"));
let path = Path::new("");
let s: Result<&str, de::value::Error> =

View file

@ -1,5 +1,9 @@
# Changes
## [0.1.11] - 2022-01-23
* Remove useless stream::Dispatcher and sink::SinkService
## [0.1.10] - 2022-01-17
* Add time::query_system_time(), it does not use async runtime

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-util"
version = "0.1.10"
version = "0.1.11"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for ntex framework"
keywords = ["network", "framework", "async", "futures"]

View file

@ -3,8 +3,6 @@ pub mod counter;
mod extensions;
pub mod inflight;
pub mod keepalive;
pub mod sink;
pub mod stream;
pub mod timeout;
pub mod variant;

View file

@ -1,82 +0,0 @@
use std::{
cell::Cell, cell::RefCell, marker::PhantomData, pin::Pin, task::Context, task::Poll,
};
use futures_sink::Sink;
use ntex_service::Service;
use crate::future::Ready;
/// `SinkService` forwards incoming requests to the provided `Sink`
pub struct SinkService<S, I> {
sink: RefCell<S>,
shutdown: Cell<bool>,
_t: PhantomData<I>,
}
impl<S, I> SinkService<S, I>
where
S: Sink<I> + Unpin,
{
/// Create new `SinnkService` instance
pub fn new(sink: S) -> Self {
SinkService {
sink: RefCell::new(sink),
shutdown: Cell::new(false),
_t: PhantomData,
}
}
}
impl<S, I> Clone for SinkService<S, I>
where
S: Clone,
{
fn clone(&self) -> Self {
SinkService {
sink: self.sink.clone(),
shutdown: self.shutdown.clone(),
_t: PhantomData,
}
}
}
impl<S, I> Service<I> for SinkService<S, I>
where
S: Sink<I> + Unpin,
{
type Response = ();
type Error = S::Error;
type Future = Ready<(), S::Error>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut inner = self.sink.borrow_mut();
let pending1 = Pin::new(&mut *inner).poll_flush(cx)?.is_pending();
let pending2 = Pin::new(&mut *inner).poll_ready(cx)?.is_pending();
if pending1 || pending2 {
Poll::Pending
} else {
Poll::Ready(Ok(()))
}
}
fn poll_shutdown(&self, cx: &mut Context<'_>, _: bool) -> Poll<()> {
if !self.shutdown.get() {
if Pin::new(&mut *self.sink.borrow_mut())
.poll_close(cx)
.is_pending()
{
Poll::Pending
} else {
self.shutdown.set(true);
Poll::Ready(())
}
} else {
Poll::Ready(())
}
}
fn call(&self, req: I) -> Self::Future {
Ready::from(Pin::new(&mut *self.sink.borrow_mut()).start_send(req))
}
}

View file

@ -1,212 +0,0 @@
use std::{fmt, future::Future, pin::Pin, task::Context, task::Poll};
use log::trace;
use ntex_service::{IntoService, Service};
use crate::channel::mpsc;
use crate::{future::poll_fn, Sink, Stream};
pin_project_lite::pin_project! {
pub struct Dispatcher<Req, R, S, T, U>
where
R: 'static,
S: Service<Req, Response = Option<R>>,
S: 'static,
T: Stream<Item = Result<Req, S::Error>>,
T: Unpin,
U: Sink<Result<R, S::Error>>,
U: Unpin,
{
#[pin]
service: S,
stream: T,
sink: Option<U>,
rx: mpsc::Receiver<Result<S::Response, S::Error>>,
shutdown: Option<bool>,
}
}
impl<Req, R, S, T, U> Dispatcher<Req, R, S, T, U>
where
R: 'static,
S: Service<Req, Response = Option<R>> + 'static,
S::Error: fmt::Debug,
T: Stream<Item = Result<Req, S::Error>> + Unpin,
U: Sink<Result<R, S::Error>> + Unpin + 'static,
U::Error: fmt::Debug,
{
pub fn new<F>(stream: T, sink: U, service: F) -> Self
where
F: IntoService<S, Req>,
{
Dispatcher {
stream,
sink: Some(sink),
service: service.into_service(),
rx: mpsc::channel().1,
shutdown: None,
}
}
}
impl<Req, R, S, T, U> Future for Dispatcher<Req, R, S, T, U>
where
R: 'static,
S: Service<Req, Response = Option<R>> + 'static,
S::Future: 'static,
S::Error: fmt::Debug + 'static,
T: Stream<Item = Result<Req, S::Error>> + Unpin,
U: Sink<Result<R, S::Error>> + Unpin + 'static,
U::Error: fmt::Debug,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
if let Some(is_err) = this.shutdown {
if let Some(mut sink) = this.sink.take() {
crate::spawn(async move {
if poll_fn(|cx| Pin::new(&mut sink).poll_flush(cx))
.await
.is_ok()
{
let _ = poll_fn(|cx| Pin::new(&mut sink).poll_close(cx)).await;
}
});
}
if this.service.poll_shutdown(cx, *is_err).is_pending() {
return Poll::Pending;
}
return Poll::Ready(());
}
loop {
match Pin::new(this.sink.as_mut().unwrap()).poll_ready(cx) {
Poll::Pending => {
match Pin::new(this.sink.as_mut().unwrap()).poll_flush(cx) {
Poll::Pending => break,
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(e)) => {
trace!("Sink flush failed: {:?}", e);
*this.shutdown = Some(true);
return self.poll(cx);
}
}
}
Poll::Ready(Ok(_)) => {
if let Poll::Ready(Some(item)) = Pin::new(&mut this.rx).poll_next(cx) {
match item {
Ok(Some(item)) => {
if let Err(e) = Pin::new(this.sink.as_mut().unwrap())
.start_send(Ok(item))
{
trace!("Failed to write to sink: {:?}", e);
*this.shutdown = Some(true);
return self.poll(cx);
}
continue;
}
Ok(None) => continue,
Err(e) => {
trace!("Stream is failed: {:?}", e);
let _ = Pin::new(this.sink.as_mut().unwrap())
.start_send(Err(e));
*this.shutdown = Some(true);
return self.poll(cx);
}
}
}
}
Poll::Ready(Err(e)) => {
trace!("Sink readiness check failed: {:?}", e);
*this.shutdown = Some(true);
return self.poll(cx);
}
}
break;
}
loop {
return match this.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => match Pin::new(&mut this.stream).poll_next(cx) {
Poll::Ready(Some(Ok(item))) => {
let tx = this.rx.sender();
let fut = this.service.call(item);
crate::spawn(async move {
let res = fut.await;
let _ = tx.send(res);
});
this = self.as_mut().project();
continue;
}
Poll::Pending => Poll::Pending,
Poll::Ready(Some(Err(_))) => {
*this.shutdown = Some(true);
return self.poll(cx);
}
Poll::Ready(None) => {
*this.shutdown = Some(false);
return self.poll(cx);
}
},
Poll::Ready(Err(e)) => {
trace!("Service readiness check failed: {:?}", e);
*this.shutdown = Some(true);
return self.poll(cx);
}
Poll::Pending => Poll::Pending,
};
}
}
}
#[cfg(test)]
mod tests {
use std::{cell::Cell, rc::Rc};
use ntex::{codec::Encoder, ws};
use ntex_bytes::{ByteString, BytesMut};
use super::*;
use crate::{channel::mpsc, future::stream_recv, time::sleep, time::Millis};
#[ntex_macros::rt_test2]
async fn test_basic() {
let counter = Rc::new(Cell::new(0));
let counter2 = counter.clone();
let (tx1, mut rx) = mpsc::channel();
let (tx, rx2) = mpsc::channel();
let encoder = ws::StreamEncoder::new(tx1);
let decoder = ws::StreamDecoder::new(rx2);
let disp = Dispatcher::new(
decoder,
encoder,
ntex_service::fn_service(move |_| {
counter2.set(counter2.get() + 1);
async { Ok(Some(ws::Message::Text(ByteString::from_static("test")))) }
}),
);
crate::spawn(async move {
let _ = disp.await;
});
let mut buf = BytesMut::new();
let codec = ws::Codec::new().client_mode();
codec
.encode(ws::Message::Text(ByteString::from_static("test")), &mut buf)
.unwrap();
tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap();
let data = stream_recv(&mut rx).await.unwrap().unwrap();
assert_eq!(data, b"\x81\x04test".as_ref());
drop(tx);
sleep(Millis(10)).await;
assert!(stream_recv(&mut rx).await.is_none());
assert_eq!(counter.get(), 1);
}
}

View file

@ -1,6 +1,8 @@
# Changes
## [0.5.11] - 2022-01-xx
## [0.5.11] - 2022-01-23
* web: Refactor ws support
* web: Add types::Payload::recv() and types::Payload::poll_recv() methods

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.5.10"
version = "0.5.11"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -52,11 +52,11 @@ ntex-codec = "0.6.1"
ntex-router = "0.5.1"
ntex-service = "0.3.1"
ntex-macros = "0.1.3"
ntex-util = "0.1.9"
ntex-util = "0.1.10"
ntex-bytes = "0.1.9"
ntex-tls = "0.1.2"
ntex-rt = "0.4.1"
ntex-io = "0.1.4"
ntex-rt = "0.4.3"
ntex-io = "0.1.5"
ntex-tokio = "0.1.2"
ntex-glommio = { version = "0.1.0", optional = true }
ntex-async-std = { version = "0.1.0", optional = true }

View file

@ -256,10 +256,12 @@ impl MessageType for Request {
PayloadLength::Payload(pl) => pl,
PayloadLength::Upgrade => {
// upgrade(websocket)
msg.head_mut().set_upgrade();
PayloadType::Stream(PayloadDecoder::eof())
}
PayloadLength::None => {
if method == Method::CONNECT {
msg.head_mut().set_upgrade();
PayloadType::Stream(PayloadDecoder::eof())
} else {
PayloadType::None

View file

@ -1,14 +1,15 @@
//! Framed transport dispatcher
use std::task::{Context, Poll};
use std::{error::Error, future::Future, io, marker, pin::Pin, rc::Rc};
use std::{cell::RefCell, error::Error, future::Future, io, marker, pin::Pin, rc::Rc};
use crate::io::{Filter, Io, IoRef, RecvError};
use crate::io::{Filter, Io, IoBoxed, RecvError};
use crate::{service::Service, util::ready, util::Bytes};
use crate::http;
use crate::http::body::{BodySize, MessageBody, ResponseBody};
use crate::http::config::DispatcherConfig;
use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError};
use crate::http::message::CurrentIo;
use crate::http::request::Request;
use crate::http::response::Response;
@ -26,8 +27,10 @@ bitflags::bitflags! {
const KEEPALIVE_REG = 0b0000_0100;
/// Upgrade request
const UPGRADE = 0b0000_1000;
/// Handling upgrade
const UPGRADE_HND = 0b0001_0000;
/// Stop after sending payload
const SENDPAYLOAD_AND_STOP = 0b0001_0000;
const SENDPAYLOAD_AND_STOP = 0b0010_0000;
}
}
@ -52,6 +55,8 @@ enum State<B> {
},
#[display(fmt = "State::Upgrade")]
Upgrade(Option<Request>),
#[display(fmt = "State::StopIo")]
StopIo(Box<(IoBoxed, Codec)>),
Stop,
}
@ -60,6 +65,7 @@ pin_project_lite::pin_project! {
enum CallState<S: Service<Request>, X: Service<Request>> {
None,
Service { #[pin] fut: S::Future },
ServiceUpgrade { #[pin] fut: S::Future },
Expect { #[pin] fut: X::Future },
Filter { fut: Pin<Box<dyn Future<Output = Result<Request, Response>>>> }
}
@ -69,7 +75,6 @@ struct DispatcherInner<F, S, B, X, U> {
io: Io<F>,
flags: Flags,
codec: Codec,
state: IoRef,
config: Rc<DispatcherConfig<S, X, U>>,
error: Option<DispatchError>,
payload: Option<(PayloadDecoder, PayloadSender)>,
@ -89,7 +94,6 @@ where
{
/// Construct new `Dispatcher` instance with outgoing messages stream.
pub(in crate::http) fn new(io: Io<F>, config: Rc<DispatcherConfig<S, X, U>>) -> Self {
let state = io.get_ref();
let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled());
io.set_disconnect_timeout(config.client_disconnect.into());
@ -102,7 +106,6 @@ where
inner: DispatcherInner {
io,
codec,
state,
config,
flags: Flags::KEEPALIVE_REG,
error: None,
@ -113,11 +116,6 @@ where
}
}
macro_rules! set_error ({ $slf:tt, $err:ident } => {
*$slf.st = State::Stop;
$slf.inner.error = Some($err);
});
impl<F, S, B, X, U> Future for Dispatcher<F, S, B, X, U>
where
F: Filter,
@ -143,7 +141,7 @@ where
Poll::Ready(result) => match result {
Ok(res) => {
let (res, body) = res.into().into_parts();
*this.st = this.inner.send_response(res, body)
*this.st = this.inner.send_response(res, body);
}
Err(e) => *this.st = this.inner.handle_error(e, false),
},
@ -154,7 +152,8 @@ where
if let Err(e) =
ready!(this.inner.poll_request_payload(cx))
{
set_error!(this, e);
*this.st = State::Stop;
this.inner.error = Some(e);
}
} else {
return Poll::Pending;
@ -163,6 +162,43 @@ where
}
None
}
// special handling for upgrade requests.
// we cannot continue to handle requests, because Io<F> get
// converted to IoBoxed before we set it to request,
// so we have to send response and disconnect. request payload
// handling should be handled by service
CallStateProject::ServiceUpgrade { fut } => {
let result = ready!(fut.poll(cx));
match result {
Ok(res) => {
let (msg, body) = res.into().into_parts();
let item = if let Some(item) = msg.head().take_io() {
item
} else {
return Poll::Ready(Ok(()));
};
let _ = item
.0
.encode(Message::Item((msg, body.size())), &item.1);
match body.size() {
BodySize::None | BodySize::Empty => {}
_ => {
log::error!("Stream responses are not supported for upgrade requests");
}
}
*this.st = State::StopIo(item);
}
Err(e) => {
log::error!(
"Cannot handle error for upgrade handler: {:?}",
e
);
return Poll::Ready(Ok(()));
}
}
None
}
// handle EXPECT call
// expect service call must resolve before
// we can do any more io processing.
@ -170,7 +206,7 @@ where
// TODO: check keep-alive timer interaction
CallStateProject::Expect { fut } => match ready!(fut.poll(cx)) {
Ok(req) => {
let result = this.inner.state.with_write_buf(|buf| {
let result = this.inner.io.with_write_buf(|buf| {
buf.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n")
});
if result.is_err() {
@ -181,6 +217,11 @@ where
*this.st = State::Upgrade(Some(req));
this = self.as_mut().project();
continue;
} else if this.inner.flags.contains(Flags::UPGRADE_HND) {
// Handle upgrade requests
Some(CallState::ServiceUpgrade {
fut: this.inner.config.service.call(req),
})
} else {
Some(CallState::Service {
fut: this.inner.config.service.call(req),
@ -204,6 +245,12 @@ where
Some(CallState::Expect {
fut: this.inner.config.expect.call(req),
})
} else if this.inner.flags.contains(Flags::UPGRADE_HND)
{
// Handle upgrade requests
Some(CallState::ServiceUpgrade {
fut: this.inner.config.service.call(req),
})
} else {
// Handle normal requests
Some(CallState::Service {
@ -238,7 +285,6 @@ where
req,
pl
);
req.head_mut().io = Some(this.inner.state.clone());
// configure request payload
let upgrade = match pl {
@ -272,18 +318,38 @@ where
log::trace!("prep io for upgrade handler");
*this.st = State::Upgrade(Some(req));
} else {
if req.upgrade() {
this.inner.flags.insert(Flags::UPGRADE_HND);
let io: IoBoxed = this.inner.io.take().into();
req.head_mut().io = CurrentIo::Io(Rc::new((
io.get_ref(),
RefCell::new(Some(Box::new((
io,
this.inner.codec.clone(),
)))),
)));
} else {
req.head_mut().io =
CurrentIo::Ref(this.inner.io.get_ref());
}
*this.st = State::Call;
this.call.set(
if let Some(ref f) = this.inner.config.on_request {
// Handle filter fut
CallState::Filter {
fut: f.call((req, this.inner.state.clone())),
fut: f.call((req, this.inner.io.get_ref())),
}
} else if req.head().expect() {
// Handle normal requests with EXPECT: 100-Continue` header
CallState::Expect {
fut: this.inner.config.expect.call(req),
}
} else if this.inner.flags.contains(Flags::UPGRADE_HND)
{
// Handle upgrade requests
CallState::ServiceUpgrade {
fut: this.inner.config.service.call(req),
}
} else {
// Handle normal requests
CallState::Service {
@ -401,6 +467,10 @@ where
Poll::Ready(Ok(()))
};
}
// prepare to shutdown
State::StopIo(ref item) => {
return item.0.poll_shutdown(cx).map_err(From::from)
}
}
}
}
@ -416,7 +486,7 @@ where
fn switch_to_read_request(&mut self) -> State<B> {
// connection is not keep-alive, disconnect
if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() {
self.state.close();
self.io.close();
State::Stop
} else {
State::ReadRequest
@ -457,7 +527,7 @@ where
// we dont need to process responses if socket is disconnected
// but we still want to handle requests with app service
// so we skip response processing for droppped connection
if self.state.is_closed() {
if self.io.is_closed() {
State::Stop
} else {
let result = self
@ -751,14 +821,14 @@ mod tests {
sleep(Millis(50)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(h1.inner.state.is_closed());
assert!(h1.inner.io.is_closed());
sleep(Millis(50)).await;
client.local_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n"));
client.close().await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(h1.inner.state.is_closed());
assert!(h1.inner.io.is_closed());
}
#[crate::rt_test]
@ -916,7 +986,7 @@ mod tests {
let _ = lazy(|cx| Pin::new(&mut h1).poll(cx)).await;
sleep(Millis(550)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(h1.inner.state.is_closed());
assert!(h1.inner.io.is_closed());
let mut buf = client.read().await.unwrap();
assert_eq!(load(&mut decoder, &mut buf).status, StatusCode::BAD_REQUEST);
@ -990,7 +1060,7 @@ mod tests {
Ok::<_, io::Error>(Response::Ok().message_body(Stream(n.clone())))
})
});
let state = h1.inner.state.clone();
let state = h1.inner.io.get_ref();
// do not allow to write to socket
client.remote_buffer_cap(0);
@ -1084,7 +1154,7 @@ mod tests {
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
sleep(Millis(50)).await;
assert!(h1.inner.state.is_closed());
assert!(h1.inner.io.is_closed());
let buf = client.local_buffer(|buf| buf.split().freeze());
assert_eq!(&buf[..28], b"HTTP/1.1 500 Internal Server");
assert_eq!(&buf[buf.len() - 5..], b"error");

View file

@ -11,9 +11,8 @@ use crate::http::error::{DispatchError, ResponseError};
use crate::http::header::{
HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING,
};
use crate::http::{
message::ResponseHead, payload::Payload, request::Request, response::Response,
};
use crate::http::message::{CurrentIo, ResponseHead};
use crate::http::{payload::Payload, request::Request, response::Response};
use crate::io::{IoRef, TokioIoBoxed};
use crate::service::Service;
use crate::time::{now, Sleep};
@ -105,7 +104,7 @@ where
head.method = parts.method;
head.version = parts.version;
head.headers = parts.headers.into();
head.io = Some(this.io.clone());
head.io = CurrentIo::Ref(this.io.clone());
crate::rt::spawn(ServiceResponse {
state: ServiceResponseState::ServiceCall {

View file

@ -3,8 +3,8 @@ use std::{cell::Ref, cell::RefCell, cell::RefMut, net, rc::Rc};
use bitflags::bitflags;
use crate::http::header::HeaderMap;
use crate::http::{header, Method, StatusCode, Uri, Version};
use crate::io::{types, IoRef};
use crate::http::{h1::Codec, Method, StatusCode, Uri, Version};
use crate::io::{types, IoBoxed, IoRef};
use crate::util::Extensions;
/// Represents various types of connection
@ -28,7 +28,6 @@ bitflags! {
}
}
#[doc(hidden)]
pub(crate) trait Head: Default + 'static {
fn clear(&mut self);
@ -37,6 +36,23 @@ pub(crate) trait Head: Default + 'static {
F: FnOnce(&MessagePool<Self>) -> R;
}
#[derive(Clone, Debug)]
pub(crate) enum CurrentIo {
Ref(IoRef),
Io(Rc<(IoRef, RefCell<Option<Box<(IoBoxed, Codec)>>>)>),
None,
}
impl CurrentIo {
pub(crate) fn as_ref(&self) -> Option<&IoRef> {
match self {
CurrentIo::Ref(ref io) => Some(io),
CurrentIo::Io(ref io) => Some(&io.0),
CurrentIo::None => None,
}
}
}
#[derive(Debug)]
pub struct RequestHead {
pub uri: Uri,
@ -44,14 +60,14 @@ pub struct RequestHead {
pub version: Version,
pub headers: HeaderMap,
pub extensions: RefCell<Extensions>,
pub io: Option<IoRef>,
pub(crate) io: CurrentIo,
pub(crate) flags: Flags,
}
impl Default for RequestHead {
fn default() -> RequestHead {
RequestHead {
io: None,
io: CurrentIo::None,
uri: Uri::default(),
method: Method::default(),
version: Version::HTTP_11,
@ -64,7 +80,7 @@ impl Default for RequestHead {
impl Head for RequestHead {
fn clear(&mut self) {
self.io = None;
self.io = CurrentIo::None;
self.flags = Flags::empty();
self.headers.clear();
self.extensions.get_mut().clear();
@ -127,17 +143,16 @@ impl RequestHead {
}
}
#[inline]
/// Connection upgrade status
pub fn upgrade(&self) -> bool {
if let Some(hdr) = self.headers().get(header::CONNECTION) {
if let Ok(s) = hdr.to_str() {
s.to_ascii_lowercase().contains("upgrade")
} else {
false
}
} else {
false
}
self.flags.contains(Flags::UPGRADE)
}
#[inline]
/// Request contains `EXPECT` header
pub fn expect(&self) -> bool {
self.flags.contains(Flags::EXPECT)
}
#[inline]
@ -156,14 +171,13 @@ impl RequestHead {
}
#[inline]
/// Request contains `EXPECT` header
pub fn expect(&self) -> bool {
self.flags.contains(Flags::EXPECT)
pub(crate) fn set_expect(&mut self) {
self.flags.insert(Flags::EXPECT);
}
#[inline]
pub(crate) fn set_expect(&mut self) {
self.flags.insert(Flags::EXPECT);
pub(crate) fn set_upgrade(&mut self) {
self.flags.insert(Flags::UPGRADE);
}
/// Peer socket address
@ -178,6 +192,16 @@ impl RequestHead {
.map(types::PeerAddr::into_inner)
})
}
/// Take io and codec for current request
///
/// This objects are set only for upgrade requests
pub fn take_io(&self) -> Option<Box<(IoBoxed, Codec)>> {
match self.io {
CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(),
_ => None,
}
}
}
#[derive(Debug)]
@ -216,6 +240,7 @@ pub struct ResponseHead {
pub status: StatusCode,
pub headers: HeaderMap,
pub reason: Option<&'static str>,
pub(crate) io: CurrentIo,
pub(crate) extensions: RefCell<Extensions>,
flags: Flags,
}
@ -230,6 +255,7 @@ impl ResponseHead {
headers: HeaderMap::with_capacity(12),
reason: None,
flags: Flags::empty(),
io: CurrentIo::None,
extensions: RefCell::new(Extensions::new()),
}
}
@ -335,6 +361,17 @@ impl ResponseHead {
self.flags.remove(Flags::NO_CHUNKING);
}
}
pub(crate) fn set_io(&mut self, head: &RequestHead) {
self.io = head.io.clone();
}
pub(crate) fn take_io(&self) -> Option<Box<(IoBoxed, Codec)>> {
match self.io {
CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(),
_ => None,
}
}
}
impl Default for ResponseHead {
@ -347,6 +384,7 @@ impl Head for ResponseHead {
fn clear(&mut self) {
self.reason = None;
self.headers.clear();
self.io = CurrentIo::None;
self.flags = Flags::empty();
}

View file

@ -10,7 +10,7 @@ use crate::{time::Millis, time::Seconds, util::Bytes};
use super::client::{Client, ClientRequest, ClientResponse, Connector};
use super::error::{HttpError, PayloadError};
use super::header::{HeaderMap, HeaderName, HeaderValue};
use super::header::{self, HeaderMap, HeaderName, HeaderValue};
use super::payload::Payload;
use super::{Method, Request, Uri, Version};
@ -148,6 +148,14 @@ impl TestRequest {
head.version = inner.version;
head.headers = inner.headers;
if let Some(conn) = head.headers.get(header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if s.to_lowercase().contains("upgrade") {
head.set_upgrade()
}
}
}
#[cfg(feature = "cookie")]
{
use percent_encoding::percent_encode;

View file

@ -108,9 +108,7 @@ impl WebResponse {
pub fn take_body(&mut self) -> ResponseBody<Body> {
self.response.take_body()
}
}
impl WebResponse {
/// Set a new body
pub fn map_body<F>(self, f: F) -> WebResponse
where
@ -126,7 +124,11 @@ impl WebResponse {
}
impl From<WebResponse> for Response<Body> {
fn from(res: WebResponse) -> Response<Body> {
fn from(mut res: WebResponse) -> Response<Body> {
let head = res.response.head_mut();
if head.upgrade() {
head.set_io(res.request.head());
}
res.response
}
}

View file

@ -1,111 +1,106 @@
use std::{error, marker::PhantomData, pin::Pin, task::Context, task::Poll};
//! WebSockets protocol support
use std::fmt;
pub use crate::ws::{CloseCode, CloseReason, Frame, Message};
pub use crate::ws::{CloseCode, CloseReason, Frame, Message, WsSink};
use crate::http::body::{Body, BoxedBodyStream};
use crate::http::error::PayloadError;
use crate::service::{IntoServiceFactory, Service, ServiceFactory};
use crate::http::{body::BodySize, h1, StatusCode};
use crate::service::{
apply_fn, fn_factory_with_config, IntoServiceFactory, Service, ServiceFactory,
};
use crate::web::{HttpRequest, HttpResponse};
use crate::ws::{error::HandshakeError, handshake};
use crate::{channel::mpsc, rt, util::Bytes, util::Sink, util::Stream, ws};
use crate::ws::{error::HandshakeError, error::WsError, handshake};
use crate::{io::DispatchItem, rt, util::Either, util::Ready, ws};
pub type WebSocketsSink =
ws::StreamEncoder<mpsc::Sender<Result<Bytes, Box<dyn error::Error>>>>;
// TODO: fix close frame handling
/// Do websocket handshake and start websockets service.
pub async fn start<T, F, S, Err>(
req: HttpRequest,
payload: S,
factory: F,
) -> Result<HttpResponse, Err>
pub async fn start<T, F, Err>(req: HttpRequest, factory: F) -> Result<HttpResponse, Err>
where
T: ServiceFactory<Frame, WebSocketsSink, Response = Option<Message>> + 'static,
T::Error: error::Error,
F: IntoServiceFactory<T, Frame, WebSocketsSink>,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
T: ServiceFactory<Frame, WsSink, Response = Option<Message>> + 'static,
T::Error: fmt::Debug,
F: IntoServiceFactory<T, Frame, WsSink>,
Err: From<T::InitError> + From<HandshakeError>,
{
let (tx, rx) = mpsc::channel();
let inner_factory = factory.into_factory().map_err(WsError::Service);
start_with(req, payload, tx, rx, factory).await
let factory = fn_factory_with_config(move |sink: WsSink| {
let fut = inner_factory.new_service(sink.clone());
async move {
let srv = fut.await?;
Ok::<_, T::InitError>(apply_fn(srv, move |req, srv| match req {
DispatchItem::Item(item) => {
let s = if matches!(item, Frame::Close(_)) {
Some(sink.clone())
} else {
None
};
let fut = srv.call(item);
Either::Left(async move {
let result = fut.await;
if let Some(s) = s {
rt::spawn(async move { s.io().close() });
}
result
})
}
DispatchItem::WBackPressureEnabled
| DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)),
DispatchItem::KeepAliveTimeout => {
Either::Right(Ready::Err(WsError::KeepAlive))
}
DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => {
Either::Right(Ready::Err(WsError::Protocol(e)))
}
DispatchItem::Disconnect(e) => {
Either::Right(Ready::Err(WsError::Disconnected(e)))
}
}))
}
});
start_with(req, factory).await
}
/// Do websocket handshake and start websockets service.
pub async fn start_with<T, F, S, Err, Tx, Rx>(
pub async fn start_with<T, F, Err>(
req: HttpRequest,
payload: S,
tx: Tx,
rx: Rx,
factory: F,
) -> Result<HttpResponse, Err>
where
T: ServiceFactory<Frame, ws::StreamEncoder<Tx>, Response = Option<Message>> + 'static,
T::Error: error::Error,
F: IntoServiceFactory<T, Frame, ws::StreamEncoder<Tx>>,
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
T: ServiceFactory<DispatchItem<ws::Codec>, WsSink, Response = Option<Message>>
+ 'static,
T::Error: fmt::Debug,
F: IntoServiceFactory<T, DispatchItem<ws::Codec>, WsSink>,
Err: From<T::InitError> + From<HandshakeError>,
Tx: Sink<Result<Bytes, Box<dyn error::Error>>> + Clone + Unpin + 'static,
Tx::Error: error::Error,
Rx: Stream<Item = Result<Bytes, Box<dyn error::Error>>> + Unpin + 'static,
{
// ws handshake
let mut res = handshake(req.head())?;
log::trace!("Start ws handshake verification for {:?}", req.path());
// converter wraper from ws::Message to Bytes
let sink = ws::StreamEncoder::new(tx);
// ws handshake
let res = handshake(req.head())?.finish().into_parts().0;
// extract io
let item = req
.head()
.take_io()
.ok_or(HandshakeError::NoWebsocketUpgrade)?;
let io = item.0;
let codec = item.1;
io.encode(h1::Message::Item((res, BodySize::Empty)), &codec)
.map_err(|_| HandshakeError::NoWebsocketUpgrade)?;
log::trace!("Ws handshake verification completed for {:?}", req.path());
// create sink
let codec = ws::Codec::new();
let sink = WsSink::new(io.get_ref(), codec.clone());
// create ws service
let srv = factory
.into_factory()
.new_service(sink.clone())
.await?
.map_err(|e| {
let e: Box<dyn error::Error> = Box::new(e);
e
});
let srv = factory.into_factory().new_service(sink).await?;
// start websockets service dispatcher
rt::spawn(crate::util::stream::Dispatcher::new(
// wrap bytes stream to ws::Frame's stream
MapStream {
stream: ws::StreamDecoder::new(payload),
_t: PhantomData,
},
// converter wraper from ws::Message to Bytes
sink,
// websockets handler service
srv,
));
rt::spawn(async move {
let res = crate::io::Dispatcher::new(io, codec, srv).await;
log::trace!("Ws handler is terminated: {:?}", res);
});
Ok(res.body(Body::from_message(BoxedBodyStream::new(rx))))
}
pin_project_lite::pin_project! {
struct MapStream<S, I, E>{
#[pin]
stream: S,
_t: PhantomData<(I, E)>,
}
}
impl<S, I, E> Stream for MapStream<S, I, E>
where
S: Stream<Item = Result<I, E>>,
E: error::Error + 'static,
{
type Item = Result<I, Box<dyn error::Error>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.project().stream.poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(Ok(item))) => Poll::Ready(Some(Ok(item))),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Box::new(err)))),
Poll::Ready(None) => Poll::Ready(None),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.stream.size_hint()
}
Ok(HttpResponse::new(StatusCode::OK))
}

View file

@ -17,7 +17,7 @@ use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1};
use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri};
use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Sealed};
use crate::service::{apply_fn, into_service, IntoService, Service};
use crate::util::{sink, Either, Ready};
use crate::util::{Either, Ready};
use crate::{channel::mpsc, rt, time::timeout, time::Millis, ws};
use super::error::{WsClientBuilderError, WsClientError, WsError};
@ -695,29 +695,25 @@ impl<F> WsConnection<F> {
impl WsConnection<Sealed> {
// TODO: fix close frame handling
/// Start client websockets with `SinkService` and `mpsc::Receiver<Frame>`
pub fn start_default(self) -> mpsc::Receiver<Result<ws::Frame, WsError<()>>> {
pub fn receiver(self) -> mpsc::Receiver<Result<ws::Frame, WsError<()>>> {
let (tx, rx): (_, mpsc::Receiver<Result<ws::Frame, WsError<()>>>) = mpsc::channel();
rt::spawn(async move {
let tx2 = tx.clone();
let io = self.io.get_ref();
let srv = sink::SinkService::new(tx.clone()).map(|_| None);
if let Err(err) = self
.start(into_service(move |item| {
let io = io.clone();
let close = matches!(item, ws::Frame::Close(_));
let fut = srv.call(Ok::<_, WsError<()>>(item));
async move {
let result = fut.await.map_err(|_| ());
if close {
io.close();
}
result
}
let result = self
.start(into_service(move |item: ws::Frame| {
match tx.send(Ok(item)) {
Ok(()) => (),
Err(_) => io.close(),
};
Ready::Ok::<Option<ws::Message>, ()>(None)
}))
.await
{
let _ = tx.send(Err(err));
.await;
if let Err(e) = result {
let _ = tx2.send(Err(e));
}
});

View file

@ -10,7 +10,6 @@ mod handshake;
mod mask;
mod proto;
mod sink;
mod stream;
mod transport;
pub mod error;
@ -21,5 +20,4 @@ pub use self::frame::Parser;
pub use self::handshake::{handshake, handshake_response, verify_handshake};
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
pub use self::sink::WsSink;
pub use self::stream::{StreamDecoder, StreamEncoder};
pub use self::transport::{WsTransport, WsTransportFactory};

View file

@ -3,6 +3,7 @@ use std::{future::Future, rc::Rc};
use crate::io::{IoRef, OnDisconnect};
use crate::ws;
#[derive(Clone)]
pub struct WsSink(Rc<WsSinkInner>);
struct WsSinkInner {
@ -15,6 +16,11 @@ impl WsSink {
Self(Rc::new(WsSinkInner { io, codec }))
}
/// Io reference
pub fn io(&self) -> &IoRef {
&self.0.io
}
/// Endcode and send message to the peer.
pub fn send(
&self,

View file

@ -1,225 +0,0 @@
use std::{
cell::RefCell, fmt, marker::PhantomData, pin::Pin, rc::Rc, task::Context, task::Poll,
};
use super::{error::ProtocolError, Codec, Frame, Message};
use crate::util::{Bytes, BytesMut, Sink, Stream};
use crate::{codec::Decoder, codec::Encoder};
/// Stream error
#[derive(Debug, Display)]
pub enum StreamError<E: fmt::Debug> {
#[display(fmt = "StreamError::Stream({:?})", _0)]
Stream(E),
Protocol(ProtocolError),
}
impl<E: fmt::Debug> std::error::Error for StreamError<E> {}
impl<E: fmt::Debug> From<ProtocolError> for StreamError<E> {
fn from(err: ProtocolError) -> Self {
StreamError::Protocol(err)
}
}
pin_project_lite::pin_project! {
/// Stream ws protocol decoder.
pub struct StreamDecoder<S, E> {
#[pin]
stream: S,
codec: Codec,
buf: BytesMut,
_t: PhantomData<E>,
}
}
impl<S, E> StreamDecoder<S, E> {
pub fn new(stream: S) -> Self {
StreamDecoder::with(stream, Codec::new())
}
pub fn with(stream: S, codec: Codec) -> Self {
StreamDecoder {
stream,
codec,
buf: BytesMut::new(),
_t: PhantomData,
}
}
}
impl<S, E> Stream for StreamDecoder<S, E>
where
S: Stream<Item = Result<Bytes, E>>,
E: fmt::Debug,
{
type Item = Result<Frame, StreamError<E>>;
#[inline]
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut().project();
loop {
if !this.buf.is_empty() {
match this.codec.decode(this.buf) {
Ok(Some(item)) => return Poll::Ready(Some(Ok(item))),
Ok(None) => (),
Err(err) => return Poll::Ready(Some(Err(err.into()))),
}
}
match this.stream.poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(buf))) => {
this.buf.extend(&buf);
this = self.as_mut().project();
}
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Some(Err(StreamError::Stream(err))))
}
Poll::Ready(None) => return Poll::Ready(None),
}
}
}
}
pin_project_lite::pin_project! {
/// Stream ws protocol decoder.
#[derive(Clone)]
pub struct StreamEncoder<S> {
#[pin]
sink: S,
codec: Rc<RefCell<Codec>>,
}
}
impl<S> StreamEncoder<S> {
pub fn new(sink: S) -> Self {
StreamEncoder::with(sink, Codec::new())
}
pub fn with(sink: S, codec: Codec) -> Self {
StreamEncoder {
sink,
codec: Rc::new(RefCell::new(codec)),
}
}
}
impl<S, E> Sink<Result<Message, E>> for StreamEncoder<S>
where
S: Sink<Result<Bytes, E>>,
S::Error: fmt::Debug,
{
type Error = StreamError<S::Error>;
#[inline]
fn poll_ready(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.project()
.sink
.poll_ready(cx)
.map_err(StreamError::Stream)
}
fn start_send(
self: Pin<&mut Self>,
item: Result<Message, E>,
) -> Result<(), Self::Error> {
let this = self.project();
match item {
Ok(item) => {
let mut buf = BytesMut::new();
this.codec.borrow_mut().encode(item, &mut buf)?;
this.sink
.start_send(Ok(buf.freeze()))
.map_err(StreamError::Stream)
}
Err(e) => this.sink.start_send(Err(e)).map_err(StreamError::Stream),
}
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.project()
.sink
.poll_flush(cx)
.map_err(StreamError::Stream)
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.project()
.sink
.poll_close(cx)
.map_err(StreamError::Stream)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
channel::mpsc, util::poll_fn, util::send, util::stream_recv, util::ByteString,
};
#[crate::rt_test]
async fn test_decoder() {
let (tx, rx) = mpsc::channel();
let mut decoder = StreamDecoder::new(rx);
let mut buf = BytesMut::new();
let codec = Codec::new().client_mode();
codec
.encode(Message::Text(ByteString::from_static("test1")), &mut buf)
.unwrap();
codec
.encode(Message::Text(ByteString::from_static("test2")), &mut buf)
.unwrap();
tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap();
let frame = stream_recv(&mut decoder).await.unwrap().unwrap();
match frame {
Frame::Text(data) => assert_eq!(data, b"test1"[..]),
_ => panic!(),
}
let frame = stream_recv(&mut decoder).await.unwrap().unwrap();
match frame {
Frame::Text(data) => assert_eq!(data, b"test2"[..]),
_ => panic!(),
}
}
#[crate::rt_test]
async fn test_encoder() {
let (tx, mut rx) = mpsc::channel();
let mut encoder = StreamEncoder::new(tx);
send(
&mut encoder,
Ok::<_, ()>(Message::Text(ByteString::from_static("test"))),
)
.await
.unwrap();
poll_fn(|cx| Pin::new(&mut encoder).poll_flush(cx))
.await
.unwrap();
poll_fn(|cx| Pin::new(&mut encoder).poll_close(cx))
.await
.unwrap();
let data = stream_recv(&mut rx).await.unwrap().unwrap();
assert_eq!(data, b"\x81\x04test".as_ref());
assert!(stream_recv(&mut rx).await.is_none());
}
}

View file

@ -1,6 +1,5 @@
use std::io;
use futures_util::StreamExt;
use ntex::http::StatusCode;
use ntex::service::{fn_factory_with_config, fn_service};
use ntex::util::{ByteString, Bytes};
@ -23,10 +22,9 @@ async fn service(msg: ws::Frame) -> Result<Option<ws::Message>, io::Error> {
async fn web_ws() {
let srv = test::server(|| {
App::new().service(web::resource("/").route(web::to(
|req: HttpRequest, pl: web::types::Payload| async move {
ws::start::<_, _, _, web::Error>(
|req: HttpRequest| async move {
ws::start::<_, _, web::Error>(
req,
pl,
fn_factory_with_config(|_| async {
Ok::<_, web::Error>(fn_service(service))
}),
@ -71,10 +69,9 @@ async fn web_ws() {
async fn web_ws_client() {
let srv = test::server(|| {
App::new().service(web::resource("/").route(web::to(
|req: HttpRequest, pl: web::types::Payload| async move {
ws::start::<_, _, _, web::Error>(
|req: HttpRequest| async move {
ws::start::<_, _, web::Error>(
req,
pl,
fn_factory_with_config(|_| async {
Ok::<_, web::Error>(fn_service(service))
}),
@ -89,33 +86,33 @@ async fn web_ws_client() {
assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS);
let sink = conn.sink();
let mut rx = conn.start_default();
let rx = conn.receiver();
sink.send(ws::Message::Text(ByteString::from_static("text")))
.await
.unwrap();
let item = rx.next().await.unwrap().unwrap();
let item = rx.recv().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text")));
sink.send(ws::Message::Binary("text".into())).await.unwrap();
let item = rx.next().await.unwrap().unwrap();
let item = rx.recv().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text")));
sink.send(ws::Message::Ping("text".into())).await.unwrap();
let item = rx.next().await.unwrap().unwrap();
let item = rx.recv().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Pong("text".to_string().into()));
let _on_disconnect = sink.on_disconnect();
let on_disconnect = sink.on_disconnect();
sink.send(ws::Message::Close(Some(ws::CloseCode::Normal.into())))
.await
.unwrap();
let item = rx.next().await.unwrap().unwrap();
let item = rx.recv().await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Away.into())));
let item = rx.next().await;
let item = rx.recv().await;
assert!(item.is_none());
// TODO fix
// on_disconnect.await
on_disconnect.await
}