From 5d9a653f7094081c69eb8a138f092f5f7ee001e6 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 23 Jan 2022 19:56:56 +0600 Subject: [PATCH] Refactor web websockets support (#97) * Refactor ws handling --- ntex-io/CHANGES.md | 4 + ntex-io/Cargo.toml | 4 +- ntex-io/src/dispatcher.rs | 1 - ntex-io/src/io.rs | 50 ++++--- ntex-io/src/ioref.rs | 18 ++- ntex-io/src/seal.rs | 1 + ntex-macros/Cargo.toml | 3 +- ntex-macros/tests/test_macro.rs | 4 - ntex-router/src/de.rs | 4 +- ntex-util/CHANGES.md | 4 + ntex-util/Cargo.toml | 2 +- ntex-util/src/services/mod.rs | 2 - ntex-util/src/services/sink.rs | 82 ----------- ntex-util/src/services/stream.rs | 212 ----------------------------- ntex/CHANGES.md | 4 +- ntex/Cargo.toml | 8 +- ntex/src/http/h1/decoder.rs | 2 + ntex/src/http/h1/dispatcher.rs | 116 ++++++++++++---- ntex/src/http/h2/dispatcher.rs | 7 +- ntex/src/http/message.rs | 78 ++++++++--- ntex/src/http/test.rs | 10 +- ntex/src/web/response.rs | 8 +- ntex/src/web/ws.rs | 167 +++++++++++------------ ntex/src/ws/client.rs | 32 ++--- ntex/src/ws/mod.rs | 2 - ntex/src/ws/sink.rs | 6 + ntex/src/ws/stream.rs | 225 ------------------------------- ntex/tests/web_ws.rs | 27 ++-- 28 files changed, 356 insertions(+), 727 deletions(-) delete mode 100644 ntex-util/src/services/sink.rs delete mode 100644 ntex-util/src/services/stream.rs delete mode 100644 ntex/src/ws/stream.rs diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 09563f71..9e5b2863 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.1.5] - 2022-01-23 + +* Add Eq,PartialEq,Hash,Debug impls to Io asn IoRef + ## [0.1.4] - 2022-01-17 * Add Io::take() method diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index cc1f1535..97b34826 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "0.1.4" +version = "0.1.5" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -16,7 +16,7 @@ name = "ntex_io" path = "src/lib.rs" [dependencies] -ntex-codec = "0.6.0" +ntex-codec = "0.6.1" ntex-bytes = "0.1.9" ntex-util = "0.1.9" ntex-service = "0.3.1" diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs index dd1b1fbd..04639d30 100644 --- a/ntex-io/src/dispatcher.rs +++ b/ntex-io/src/dispatcher.rs @@ -754,7 +754,6 @@ mod tests { #[ntex::test] async fn test_keepalive() { - env_logger::init(); let (client, server) = IoTest::create(); client.remote_buffer_cap(1024); client.write("GET /test HTTP/1\r\n\r\n"); diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index 50134fdd..86446ed5 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -699,6 +699,39 @@ impl AsRef for Io { } } +impl Eq for Io {} + +impl PartialEq for Io { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl hash::Hash for Io { + #[inline] + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +impl fmt::Debug for Io { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Io") + .field("open", &!self.is_closed()) + .finish() + } +} + +impl Deref for Io { + type Target = IoRef; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl Drop for Io { fn drop(&mut self) { self.remove_keepalive_timer(); @@ -727,23 +760,6 @@ impl Drop for Io { } } -impl fmt::Debug for Io { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Io") - .field("open", &!self.is_closed()) - .finish() - } -} - -impl Deref for Io { - type Target = IoRef; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.0 - } -} - /// OnDisconnect future resolves when socket get disconnected #[must_use = "OnDisconnect do nothing unless polled"] pub struct OnDisconnect { diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 20640edd..e962f2c9 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -1,4 +1,4 @@ -use std::{any, fmt, io}; +use std::{any, fmt, hash, io}; use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_codec::{Decoder, Encoder}; @@ -190,6 +190,22 @@ impl IoRef { } } +impl Eq for IoRef {} + +impl PartialEq for IoRef { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} + +impl hash::Hash for IoRef { + #[inline] + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + impl fmt::Debug for IoRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("IoRef") diff --git a/ntex-io/src/seal.rs b/ntex-io/src/seal.rs index afa25f04..347cd608 100644 --- a/ntex-io/src/seal.rs +++ b/ntex-io/src/seal.rs @@ -5,6 +5,7 @@ use crate::{Filter, Io}; /// Sealed filter type pub struct Sealed(pub(crate) Box); +#[derive(Debug)] /// Boxed `Io` object with erased filter type pub struct IoBoxed(Io); diff --git a/ntex-macros/Cargo.toml b/ntex-macros/Cargo.toml index 0c87c214..958c1756 100644 --- a/ntex-macros/Cargo.toml +++ b/ntex-macros/Cargo.toml @@ -16,5 +16,6 @@ syn = { version = "^1", features = ["full", "parsing"] } proc-macro2 = "^1" [dev-dependencies] -ntex = "0.5.0-b.0" +ntex = { version = "0.5.0", features = ["tokio"] } futures = "0.3" +env_logger = "0.9" \ No newline at end of file diff --git a/ntex-macros/tests/test_macro.rs b/ntex-macros/tests/test_macro.rs index 0ee4da5a..4d4fdedd 100644 --- a/ntex-macros/tests/test_macro.rs +++ b/ntex-macros/tests/test_macro.rs @@ -115,10 +115,6 @@ async fn test_body() { let response = request.send().await.unwrap(); assert!(response.status().is_success()); - let request = srv.request(Method::CONNECT, srv.url("/test")); - let response = request.send().await.unwrap(); - assert!(response.status().is_success()); - let request = srv.request(Method::OPTIONS, srv.url("/test")); let response = request.send().await.unwrap(); assert!(response.status().is_success()); diff --git a/ntex-router/src/de.rs b/ntex-router/src/de.rs index 8b274c63..12880d5d 100644 --- a/ntex-router/src/de.rs +++ b/ntex-router/src/de.rs @@ -691,12 +691,12 @@ mod tests { #[derive(Debug, Deserialize)] struct S { - inner: (String,), + _inner: (String,), } let s: Result = de::Deserialize::deserialize(PathDeserializer::new(&path)); assert!(s.is_err()); - assert!(format!("{:?}", s).contains("missing field `inner`")); + assert!(format!("{:?}", s).contains("missing field `_inner`")); let path = Path::new(""); let s: Result<&str, de::value::Error> = diff --git a/ntex-util/CHANGES.md b/ntex-util/CHANGES.md index 7a33b5ab..dbe9ca11 100644 --- a/ntex-util/CHANGES.md +++ b/ntex-util/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.1.11] - 2022-01-23 + +* Remove useless stream::Dispatcher and sink::SinkService + ## [0.1.10] - 2022-01-17 * Add time::query_system_time(), it does not use async runtime diff --git a/ntex-util/Cargo.toml b/ntex-util/Cargo.toml index 92980b8e..11ca939e 100644 --- a/ntex-util/Cargo.toml +++ b/ntex-util/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-util" -version = "0.1.10" +version = "0.1.11" authors = ["ntex contributors "] description = "Utilities for ntex framework" keywords = ["network", "framework", "async", "futures"] diff --git a/ntex-util/src/services/mod.rs b/ntex-util/src/services/mod.rs index 44840976..ea1cfede 100644 --- a/ntex-util/src/services/mod.rs +++ b/ntex-util/src/services/mod.rs @@ -3,8 +3,6 @@ pub mod counter; mod extensions; pub mod inflight; pub mod keepalive; -pub mod sink; -pub mod stream; pub mod timeout; pub mod variant; diff --git a/ntex-util/src/services/sink.rs b/ntex-util/src/services/sink.rs deleted file mode 100644 index d43f0bab..00000000 --- a/ntex-util/src/services/sink.rs +++ /dev/null @@ -1,82 +0,0 @@ -use std::{ - cell::Cell, cell::RefCell, marker::PhantomData, pin::Pin, task::Context, task::Poll, -}; - -use futures_sink::Sink; -use ntex_service::Service; - -use crate::future::Ready; - -/// `SinkService` forwards incoming requests to the provided `Sink` -pub struct SinkService { - sink: RefCell, - shutdown: Cell, - _t: PhantomData, -} - -impl SinkService -where - S: Sink + Unpin, -{ - /// Create new `SinnkService` instance - pub fn new(sink: S) -> Self { - SinkService { - sink: RefCell::new(sink), - shutdown: Cell::new(false), - _t: PhantomData, - } - } -} - -impl Clone for SinkService -where - S: Clone, -{ - fn clone(&self) -> Self { - SinkService { - sink: self.sink.clone(), - shutdown: self.shutdown.clone(), - _t: PhantomData, - } - } -} - -impl Service for SinkService -where - S: Sink + Unpin, -{ - type Response = (); - type Error = S::Error; - type Future = Ready<(), S::Error>; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let mut inner = self.sink.borrow_mut(); - let pending1 = Pin::new(&mut *inner).poll_flush(cx)?.is_pending(); - let pending2 = Pin::new(&mut *inner).poll_ready(cx)?.is_pending(); - if pending1 || pending2 { - Poll::Pending - } else { - Poll::Ready(Ok(())) - } - } - - fn poll_shutdown(&self, cx: &mut Context<'_>, _: bool) -> Poll<()> { - if !self.shutdown.get() { - if Pin::new(&mut *self.sink.borrow_mut()) - .poll_close(cx) - .is_pending() - { - Poll::Pending - } else { - self.shutdown.set(true); - Poll::Ready(()) - } - } else { - Poll::Ready(()) - } - } - - fn call(&self, req: I) -> Self::Future { - Ready::from(Pin::new(&mut *self.sink.borrow_mut()).start_send(req)) - } -} diff --git a/ntex-util/src/services/stream.rs b/ntex-util/src/services/stream.rs deleted file mode 100644 index f63d8fc6..00000000 --- a/ntex-util/src/services/stream.rs +++ /dev/null @@ -1,212 +0,0 @@ -use std::{fmt, future::Future, pin::Pin, task::Context, task::Poll}; - -use log::trace; -use ntex_service::{IntoService, Service}; - -use crate::channel::mpsc; -use crate::{future::poll_fn, Sink, Stream}; - -pin_project_lite::pin_project! { - pub struct Dispatcher - where - R: 'static, - S: Service>, - S: 'static, - T: Stream>, - T: Unpin, - U: Sink>, - U: Unpin, - { - #[pin] - service: S, - stream: T, - sink: Option, - rx: mpsc::Receiver>, - shutdown: Option, - } -} - -impl Dispatcher -where - R: 'static, - S: Service> + 'static, - S::Error: fmt::Debug, - T: Stream> + Unpin, - U: Sink> + Unpin + 'static, - U::Error: fmt::Debug, -{ - pub fn new(stream: T, sink: U, service: F) -> Self - where - F: IntoService, - { - Dispatcher { - stream, - sink: Some(sink), - service: service.into_service(), - rx: mpsc::channel().1, - shutdown: None, - } - } -} - -impl Future for Dispatcher -where - R: 'static, - S: Service> + 'static, - S::Future: 'static, - S::Error: fmt::Debug + 'static, - T: Stream> + Unpin, - U: Sink> + Unpin + 'static, - U::Error: fmt::Debug, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().project(); - - if let Some(is_err) = this.shutdown { - if let Some(mut sink) = this.sink.take() { - crate::spawn(async move { - if poll_fn(|cx| Pin::new(&mut sink).poll_flush(cx)) - .await - .is_ok() - { - let _ = poll_fn(|cx| Pin::new(&mut sink).poll_close(cx)).await; - } - }); - } - if this.service.poll_shutdown(cx, *is_err).is_pending() { - return Poll::Pending; - } - return Poll::Ready(()); - } - - loop { - match Pin::new(this.sink.as_mut().unwrap()).poll_ready(cx) { - Poll::Pending => { - match Pin::new(this.sink.as_mut().unwrap()).poll_flush(cx) { - Poll::Pending => break, - Poll::Ready(Ok(_)) => (), - Poll::Ready(Err(e)) => { - trace!("Sink flush failed: {:?}", e); - *this.shutdown = Some(true); - return self.poll(cx); - } - } - } - Poll::Ready(Ok(_)) => { - if let Poll::Ready(Some(item)) = Pin::new(&mut this.rx).poll_next(cx) { - match item { - Ok(Some(item)) => { - if let Err(e) = Pin::new(this.sink.as_mut().unwrap()) - .start_send(Ok(item)) - { - trace!("Failed to write to sink: {:?}", e); - *this.shutdown = Some(true); - return self.poll(cx); - } - continue; - } - Ok(None) => continue, - Err(e) => { - trace!("Stream is failed: {:?}", e); - let _ = Pin::new(this.sink.as_mut().unwrap()) - .start_send(Err(e)); - *this.shutdown = Some(true); - return self.poll(cx); - } - } - } - } - Poll::Ready(Err(e)) => { - trace!("Sink readiness check failed: {:?}", e); - *this.shutdown = Some(true); - return self.poll(cx); - } - } - break; - } - - loop { - return match this.service.poll_ready(cx) { - Poll::Ready(Ok(_)) => match Pin::new(&mut this.stream).poll_next(cx) { - Poll::Ready(Some(Ok(item))) => { - let tx = this.rx.sender(); - let fut = this.service.call(item); - crate::spawn(async move { - let res = fut.await; - let _ = tx.send(res); - }); - this = self.as_mut().project(); - continue; - } - Poll::Pending => Poll::Pending, - Poll::Ready(Some(Err(_))) => { - *this.shutdown = Some(true); - return self.poll(cx); - } - Poll::Ready(None) => { - *this.shutdown = Some(false); - return self.poll(cx); - } - }, - Poll::Ready(Err(e)) => { - trace!("Service readiness check failed: {:?}", e); - *this.shutdown = Some(true); - return self.poll(cx); - } - Poll::Pending => Poll::Pending, - }; - } - } -} - -#[cfg(test)] -mod tests { - use std::{cell::Cell, rc::Rc}; - - use ntex::{codec::Encoder, ws}; - use ntex_bytes::{ByteString, BytesMut}; - - use super::*; - use crate::{channel::mpsc, future::stream_recv, time::sleep, time::Millis}; - - #[ntex_macros::rt_test2] - async fn test_basic() { - let counter = Rc::new(Cell::new(0)); - let counter2 = counter.clone(); - - let (tx1, mut rx) = mpsc::channel(); - let (tx, rx2) = mpsc::channel(); - let encoder = ws::StreamEncoder::new(tx1); - let decoder = ws::StreamDecoder::new(rx2); - - let disp = Dispatcher::new( - decoder, - encoder, - ntex_service::fn_service(move |_| { - counter2.set(counter2.get() + 1); - async { Ok(Some(ws::Message::Text(ByteString::from_static("test")))) } - }), - ); - crate::spawn(async move { - let _ = disp.await; - }); - - let mut buf = BytesMut::new(); - let codec = ws::Codec::new().client_mode(); - codec - .encode(ws::Message::Text(ByteString::from_static("test")), &mut buf) - .unwrap(); - tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap(); - - let data = stream_recv(&mut rx).await.unwrap().unwrap(); - assert_eq!(data, b"\x81\x04test".as_ref()); - - drop(tx); - sleep(Millis(10)).await; - assert!(stream_recv(&mut rx).await.is_none()); - - assert_eq!(counter.get(), 1); - } -} diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index caebd113..1518cf4a 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,6 +1,8 @@ # Changes -## [0.5.11] - 2022-01-xx +## [0.5.11] - 2022-01-23 + +* web: Refactor ws support * web: Add types::Payload::recv() and types::Payload::poll_recv() methods diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 9d51e549..884b8c89 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.5.10" +version = "0.5.11" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -52,11 +52,11 @@ ntex-codec = "0.6.1" ntex-router = "0.5.1" ntex-service = "0.3.1" ntex-macros = "0.1.3" -ntex-util = "0.1.9" +ntex-util = "0.1.10" ntex-bytes = "0.1.9" ntex-tls = "0.1.2" -ntex-rt = "0.4.1" -ntex-io = "0.1.4" +ntex-rt = "0.4.3" +ntex-io = "0.1.5" ntex-tokio = "0.1.2" ntex-glommio = { version = "0.1.0", optional = true } ntex-async-std = { version = "0.1.0", optional = true } diff --git a/ntex/src/http/h1/decoder.rs b/ntex/src/http/h1/decoder.rs index 31067c7a..3cd8fc74 100644 --- a/ntex/src/http/h1/decoder.rs +++ b/ntex/src/http/h1/decoder.rs @@ -256,10 +256,12 @@ impl MessageType for Request { PayloadLength::Payload(pl) => pl, PayloadLength::Upgrade => { // upgrade(websocket) + msg.head_mut().set_upgrade(); PayloadType::Stream(PayloadDecoder::eof()) } PayloadLength::None => { if method == Method::CONNECT { + msg.head_mut().set_upgrade(); PayloadType::Stream(PayloadDecoder::eof()) } else { PayloadType::None diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 65c753be..e1f00a3a 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,14 +1,15 @@ //! Framed transport dispatcher use std::task::{Context, Poll}; -use std::{error::Error, future::Future, io, marker, pin::Pin, rc::Rc}; +use std::{cell::RefCell, error::Error, future::Future, io, marker, pin::Pin, rc::Rc}; -use crate::io::{Filter, Io, IoRef, RecvError}; +use crate::io::{Filter, Io, IoBoxed, RecvError}; use crate::{service::Service, util::ready, util::Bytes}; use crate::http; use crate::http::body::{BodySize, MessageBody, ResponseBody}; use crate::http::config::DispatcherConfig; use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError}; +use crate::http::message::CurrentIo; use crate::http::request::Request; use crate::http::response::Response; @@ -26,8 +27,10 @@ bitflags::bitflags! { const KEEPALIVE_REG = 0b0000_0100; /// Upgrade request const UPGRADE = 0b0000_1000; + /// Handling upgrade + const UPGRADE_HND = 0b0001_0000; /// Stop after sending payload - const SENDPAYLOAD_AND_STOP = 0b0001_0000; + const SENDPAYLOAD_AND_STOP = 0b0010_0000; } } @@ -52,6 +55,8 @@ enum State { }, #[display(fmt = "State::Upgrade")] Upgrade(Option), + #[display(fmt = "State::StopIo")] + StopIo(Box<(IoBoxed, Codec)>), Stop, } @@ -60,6 +65,7 @@ pin_project_lite::pin_project! { enum CallState, X: Service> { None, Service { #[pin] fut: S::Future }, + ServiceUpgrade { #[pin] fut: S::Future }, Expect { #[pin] fut: X::Future }, Filter { fut: Pin>>> } } @@ -69,7 +75,6 @@ struct DispatcherInner { io: Io, flags: Flags, codec: Codec, - state: IoRef, config: Rc>, error: Option, payload: Option<(PayloadDecoder, PayloadSender)>, @@ -89,7 +94,6 @@ where { /// Construct new `Dispatcher` instance with outgoing messages stream. pub(in crate::http) fn new(io: Io, config: Rc>) -> Self { - let state = io.get_ref(); let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled()); io.set_disconnect_timeout(config.client_disconnect.into()); @@ -102,7 +106,6 @@ where inner: DispatcherInner { io, codec, - state, config, flags: Flags::KEEPALIVE_REG, error: None, @@ -113,11 +116,6 @@ where } } -macro_rules! set_error ({ $slf:tt, $err:ident } => { - *$slf.st = State::Stop; - $slf.inner.error = Some($err); -}); - impl Future for Dispatcher where F: Filter, @@ -143,7 +141,7 @@ where Poll::Ready(result) => match result { Ok(res) => { let (res, body) = res.into().into_parts(); - *this.st = this.inner.send_response(res, body) + *this.st = this.inner.send_response(res, body); } Err(e) => *this.st = this.inner.handle_error(e, false), }, @@ -154,7 +152,8 @@ where if let Err(e) = ready!(this.inner.poll_request_payload(cx)) { - set_error!(this, e); + *this.st = State::Stop; + this.inner.error = Some(e); } } else { return Poll::Pending; @@ -163,6 +162,43 @@ where } None } + // special handling for upgrade requests. + // we cannot continue to handle requests, because Io get + // converted to IoBoxed before we set it to request, + // so we have to send response and disconnect. request payload + // handling should be handled by service + CallStateProject::ServiceUpgrade { fut } => { + let result = ready!(fut.poll(cx)); + match result { + Ok(res) => { + let (msg, body) = res.into().into_parts(); + let item = if let Some(item) = msg.head().take_io() { + item + } else { + return Poll::Ready(Ok(())); + }; + + let _ = item + .0 + .encode(Message::Item((msg, body.size())), &item.1); + match body.size() { + BodySize::None | BodySize::Empty => {} + _ => { + log::error!("Stream responses are not supported for upgrade requests"); + } + } + *this.st = State::StopIo(item); + } + Err(e) => { + log::error!( + "Cannot handle error for upgrade handler: {:?}", + e + ); + return Poll::Ready(Ok(())); + } + } + None + } // handle EXPECT call // expect service call must resolve before // we can do any more io processing. @@ -170,7 +206,7 @@ where // TODO: check keep-alive timer interaction CallStateProject::Expect { fut } => match ready!(fut.poll(cx)) { Ok(req) => { - let result = this.inner.state.with_write_buf(|buf| { + let result = this.inner.io.with_write_buf(|buf| { buf.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n") }); if result.is_err() { @@ -181,6 +217,11 @@ where *this.st = State::Upgrade(Some(req)); this = self.as_mut().project(); continue; + } else if this.inner.flags.contains(Flags::UPGRADE_HND) { + // Handle upgrade requests + Some(CallState::ServiceUpgrade { + fut: this.inner.config.service.call(req), + }) } else { Some(CallState::Service { fut: this.inner.config.service.call(req), @@ -204,6 +245,12 @@ where Some(CallState::Expect { fut: this.inner.config.expect.call(req), }) + } else if this.inner.flags.contains(Flags::UPGRADE_HND) + { + // Handle upgrade requests + Some(CallState::ServiceUpgrade { + fut: this.inner.config.service.call(req), + }) } else { // Handle normal requests Some(CallState::Service { @@ -238,7 +285,6 @@ where req, pl ); - req.head_mut().io = Some(this.inner.state.clone()); // configure request payload let upgrade = match pl { @@ -272,18 +318,38 @@ where log::trace!("prep io for upgrade handler"); *this.st = State::Upgrade(Some(req)); } else { + if req.upgrade() { + this.inner.flags.insert(Flags::UPGRADE_HND); + let io: IoBoxed = this.inner.io.take().into(); + req.head_mut().io = CurrentIo::Io(Rc::new(( + io.get_ref(), + RefCell::new(Some(Box::new(( + io, + this.inner.codec.clone(), + )))), + ))); + } else { + req.head_mut().io = + CurrentIo::Ref(this.inner.io.get_ref()); + } *this.st = State::Call; this.call.set( if let Some(ref f) = this.inner.config.on_request { // Handle filter fut CallState::Filter { - fut: f.call((req, this.inner.state.clone())), + fut: f.call((req, this.inner.io.get_ref())), } } else if req.head().expect() { // Handle normal requests with EXPECT: 100-Continue` header CallState::Expect { fut: this.inner.config.expect.call(req), } + } else if this.inner.flags.contains(Flags::UPGRADE_HND) + { + // Handle upgrade requests + CallState::ServiceUpgrade { + fut: this.inner.config.service.call(req), + } } else { // Handle normal requests CallState::Service { @@ -401,6 +467,10 @@ where Poll::Ready(Ok(())) }; } + // prepare to shutdown + State::StopIo(ref item) => { + return item.0.poll_shutdown(cx).map_err(From::from) + } } } } @@ -416,7 +486,7 @@ where fn switch_to_read_request(&mut self) -> State { // connection is not keep-alive, disconnect if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() { - self.state.close(); + self.io.close(); State::Stop } else { State::ReadRequest @@ -457,7 +527,7 @@ where // we dont need to process responses if socket is disconnected // but we still want to handle requests with app service // so we skip response processing for droppped connection - if self.state.is_closed() { + if self.io.is_closed() { State::Stop } else { let result = self @@ -751,14 +821,14 @@ mod tests { sleep(Millis(50)).await; assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); - assert!(h1.inner.state.is_closed()); + assert!(h1.inner.io.is_closed()); sleep(Millis(50)).await; client.local_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n")); client.close().await; assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); - assert!(h1.inner.state.is_closed()); + assert!(h1.inner.io.is_closed()); } #[crate::rt_test] @@ -916,7 +986,7 @@ mod tests { let _ = lazy(|cx| Pin::new(&mut h1).poll(cx)).await; sleep(Millis(550)).await; assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); - assert!(h1.inner.state.is_closed()); + assert!(h1.inner.io.is_closed()); let mut buf = client.read().await.unwrap(); assert_eq!(load(&mut decoder, &mut buf).status, StatusCode::BAD_REQUEST); @@ -990,7 +1060,7 @@ mod tests { Ok::<_, io::Error>(Response::Ok().message_body(Stream(n.clone()))) }) }); - let state = h1.inner.state.clone(); + let state = h1.inner.io.get_ref(); // do not allow to write to socket client.remote_buffer_cap(0); @@ -1084,7 +1154,7 @@ mod tests { assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); sleep(Millis(50)).await; - assert!(h1.inner.state.is_closed()); + assert!(h1.inner.io.is_closed()); let buf = client.local_buffer(|buf| buf.split().freeze()); assert_eq!(&buf[..28], b"HTTP/1.1 500 Internal Server"); assert_eq!(&buf[buf.len() - 5..], b"error"); diff --git a/ntex/src/http/h2/dispatcher.rs b/ntex/src/http/h2/dispatcher.rs index 82a5411d..90135bf4 100644 --- a/ntex/src/http/h2/dispatcher.rs +++ b/ntex/src/http/h2/dispatcher.rs @@ -11,9 +11,8 @@ use crate::http::error::{DispatchError, ResponseError}; use crate::http::header::{ HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, }; -use crate::http::{ - message::ResponseHead, payload::Payload, request::Request, response::Response, -}; +use crate::http::message::{CurrentIo, ResponseHead}; +use crate::http::{payload::Payload, request::Request, response::Response}; use crate::io::{IoRef, TokioIoBoxed}; use crate::service::Service; use crate::time::{now, Sleep}; @@ -105,7 +104,7 @@ where head.method = parts.method; head.version = parts.version; head.headers = parts.headers.into(); - head.io = Some(this.io.clone()); + head.io = CurrentIo::Ref(this.io.clone()); crate::rt::spawn(ServiceResponse { state: ServiceResponseState::ServiceCall { diff --git a/ntex/src/http/message.rs b/ntex/src/http/message.rs index c5b956d0..a5ee6a3f 100644 --- a/ntex/src/http/message.rs +++ b/ntex/src/http/message.rs @@ -3,8 +3,8 @@ use std::{cell::Ref, cell::RefCell, cell::RefMut, net, rc::Rc}; use bitflags::bitflags; use crate::http::header::HeaderMap; -use crate::http::{header, Method, StatusCode, Uri, Version}; -use crate::io::{types, IoRef}; +use crate::http::{h1::Codec, Method, StatusCode, Uri, Version}; +use crate::io::{types, IoBoxed, IoRef}; use crate::util::Extensions; /// Represents various types of connection @@ -28,7 +28,6 @@ bitflags! { } } -#[doc(hidden)] pub(crate) trait Head: Default + 'static { fn clear(&mut self); @@ -37,6 +36,23 @@ pub(crate) trait Head: Default + 'static { F: FnOnce(&MessagePool) -> R; } +#[derive(Clone, Debug)] +pub(crate) enum CurrentIo { + Ref(IoRef), + Io(Rc<(IoRef, RefCell>>)>), + None, +} + +impl CurrentIo { + pub(crate) fn as_ref(&self) -> Option<&IoRef> { + match self { + CurrentIo::Ref(ref io) => Some(io), + CurrentIo::Io(ref io) => Some(&io.0), + CurrentIo::None => None, + } + } +} + #[derive(Debug)] pub struct RequestHead { pub uri: Uri, @@ -44,14 +60,14 @@ pub struct RequestHead { pub version: Version, pub headers: HeaderMap, pub extensions: RefCell, - pub io: Option, + pub(crate) io: CurrentIo, pub(crate) flags: Flags, } impl Default for RequestHead { fn default() -> RequestHead { RequestHead { - io: None, + io: CurrentIo::None, uri: Uri::default(), method: Method::default(), version: Version::HTTP_11, @@ -64,7 +80,7 @@ impl Default for RequestHead { impl Head for RequestHead { fn clear(&mut self) { - self.io = None; + self.io = CurrentIo::None; self.flags = Flags::empty(); self.headers.clear(); self.extensions.get_mut().clear(); @@ -127,17 +143,16 @@ impl RequestHead { } } + #[inline] /// Connection upgrade status pub fn upgrade(&self) -> bool { - if let Some(hdr) = self.headers().get(header::CONNECTION) { - if let Ok(s) = hdr.to_str() { - s.to_ascii_lowercase().contains("upgrade") - } else { - false - } - } else { - false - } + self.flags.contains(Flags::UPGRADE) + } + + #[inline] + /// Request contains `EXPECT` header + pub fn expect(&self) -> bool { + self.flags.contains(Flags::EXPECT) } #[inline] @@ -156,14 +171,13 @@ impl RequestHead { } #[inline] - /// Request contains `EXPECT` header - pub fn expect(&self) -> bool { - self.flags.contains(Flags::EXPECT) + pub(crate) fn set_expect(&mut self) { + self.flags.insert(Flags::EXPECT); } #[inline] - pub(crate) fn set_expect(&mut self) { - self.flags.insert(Flags::EXPECT); + pub(crate) fn set_upgrade(&mut self) { + self.flags.insert(Flags::UPGRADE); } /// Peer socket address @@ -178,6 +192,16 @@ impl RequestHead { .map(types::PeerAddr::into_inner) }) } + + /// Take io and codec for current request + /// + /// This objects are set only for upgrade requests + pub fn take_io(&self) -> Option> { + match self.io { + CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(), + _ => None, + } + } } #[derive(Debug)] @@ -216,6 +240,7 @@ pub struct ResponseHead { pub status: StatusCode, pub headers: HeaderMap, pub reason: Option<&'static str>, + pub(crate) io: CurrentIo, pub(crate) extensions: RefCell, flags: Flags, } @@ -230,6 +255,7 @@ impl ResponseHead { headers: HeaderMap::with_capacity(12), reason: None, flags: Flags::empty(), + io: CurrentIo::None, extensions: RefCell::new(Extensions::new()), } } @@ -335,6 +361,17 @@ impl ResponseHead { self.flags.remove(Flags::NO_CHUNKING); } } + + pub(crate) fn set_io(&mut self, head: &RequestHead) { + self.io = head.io.clone(); + } + + pub(crate) fn take_io(&self) -> Option> { + match self.io { + CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(), + _ => None, + } + } } impl Default for ResponseHead { @@ -347,6 +384,7 @@ impl Head for ResponseHead { fn clear(&mut self) { self.reason = None; self.headers.clear(); + self.io = CurrentIo::None; self.flags = Flags::empty(); } diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index 478f496e..9295f324 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -10,7 +10,7 @@ use crate::{time::Millis, time::Seconds, util::Bytes}; use super::client::{Client, ClientRequest, ClientResponse, Connector}; use super::error::{HttpError, PayloadError}; -use super::header::{HeaderMap, HeaderName, HeaderValue}; +use super::header::{self, HeaderMap, HeaderName, HeaderValue}; use super::payload::Payload; use super::{Method, Request, Uri, Version}; @@ -148,6 +148,14 @@ impl TestRequest { head.version = inner.version; head.headers = inner.headers; + if let Some(conn) = head.headers.get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if s.to_lowercase().contains("upgrade") { + head.set_upgrade() + } + } + } + #[cfg(feature = "cookie")] { use percent_encoding::percent_encode; diff --git a/ntex/src/web/response.rs b/ntex/src/web/response.rs index 5ce09ecc..3466ce44 100644 --- a/ntex/src/web/response.rs +++ b/ntex/src/web/response.rs @@ -108,9 +108,7 @@ impl WebResponse { pub fn take_body(&mut self) -> ResponseBody { self.response.take_body() } -} -impl WebResponse { /// Set a new body pub fn map_body(self, f: F) -> WebResponse where @@ -126,7 +124,11 @@ impl WebResponse { } impl From for Response { - fn from(res: WebResponse) -> Response { + fn from(mut res: WebResponse) -> Response { + let head = res.response.head_mut(); + if head.upgrade() { + head.set_io(res.request.head()); + } res.response } } diff --git a/ntex/src/web/ws.rs b/ntex/src/web/ws.rs index f24b670d..4d0ce708 100644 --- a/ntex/src/web/ws.rs +++ b/ntex/src/web/ws.rs @@ -1,111 +1,106 @@ -use std::{error, marker::PhantomData, pin::Pin, task::Context, task::Poll}; +//! WebSockets protocol support +use std::fmt; -pub use crate::ws::{CloseCode, CloseReason, Frame, Message}; +pub use crate::ws::{CloseCode, CloseReason, Frame, Message, WsSink}; -use crate::http::body::{Body, BoxedBodyStream}; -use crate::http::error::PayloadError; -use crate::service::{IntoServiceFactory, Service, ServiceFactory}; +use crate::http::{body::BodySize, h1, StatusCode}; +use crate::service::{ + apply_fn, fn_factory_with_config, IntoServiceFactory, Service, ServiceFactory, +}; use crate::web::{HttpRequest, HttpResponse}; -use crate::ws::{error::HandshakeError, handshake}; -use crate::{channel::mpsc, rt, util::Bytes, util::Sink, util::Stream, ws}; +use crate::ws::{error::HandshakeError, error::WsError, handshake}; +use crate::{io::DispatchItem, rt, util::Either, util::Ready, ws}; -pub type WebSocketsSink = - ws::StreamEncoder>>>; - -// TODO: fix close frame handling /// Do websocket handshake and start websockets service. -pub async fn start( - req: HttpRequest, - payload: S, - factory: F, -) -> Result +pub async fn start(req: HttpRequest, factory: F) -> Result where - T: ServiceFactory> + 'static, - T::Error: error::Error, - F: IntoServiceFactory, - S: Stream> + Unpin + 'static, + T: ServiceFactory> + 'static, + T::Error: fmt::Debug, + F: IntoServiceFactory, Err: From + From, { - let (tx, rx) = mpsc::channel(); + let inner_factory = factory.into_factory().map_err(WsError::Service); - start_with(req, payload, tx, rx, factory).await + let factory = fn_factory_with_config(move |sink: WsSink| { + let fut = inner_factory.new_service(sink.clone()); + + async move { + let srv = fut.await?; + Ok::<_, T::InitError>(apply_fn(srv, move |req, srv| match req { + DispatchItem::Item(item) => { + let s = if matches!(item, Frame::Close(_)) { + Some(sink.clone()) + } else { + None + }; + let fut = srv.call(item); + Either::Left(async move { + let result = fut.await; + if let Some(s) = s { + rt::spawn(async move { s.io().close() }); + } + result + }) + } + DispatchItem::WBackPressureEnabled + | DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)), + DispatchItem::KeepAliveTimeout => { + Either::Right(Ready::Err(WsError::KeepAlive)) + } + DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => { + Either::Right(Ready::Err(WsError::Protocol(e))) + } + DispatchItem::Disconnect(e) => { + Either::Right(Ready::Err(WsError::Disconnected(e))) + } + })) + } + }); + + start_with(req, factory).await } /// Do websocket handshake and start websockets service. -pub async fn start_with( +pub async fn start_with( req: HttpRequest, - payload: S, - tx: Tx, - rx: Rx, factory: F, ) -> Result where - T: ServiceFactory, Response = Option> + 'static, - T::Error: error::Error, - F: IntoServiceFactory>, - S: Stream> + Unpin + 'static, + T: ServiceFactory, WsSink, Response = Option> + + 'static, + T::Error: fmt::Debug, + F: IntoServiceFactory, WsSink>, Err: From + From, - Tx: Sink>> + Clone + Unpin + 'static, - Tx::Error: error::Error, - Rx: Stream>> + Unpin + 'static, { - // ws handshake - let mut res = handshake(req.head())?; + log::trace!("Start ws handshake verification for {:?}", req.path()); - // converter wraper from ws::Message to Bytes - let sink = ws::StreamEncoder::new(tx); + // ws handshake + let res = handshake(req.head())?.finish().into_parts().0; + + // extract io + let item = req + .head() + .take_io() + .ok_or(HandshakeError::NoWebsocketUpgrade)?; + let io = item.0; + let codec = item.1; + + io.encode(h1::Message::Item((res, BodySize::Empty)), &codec) + .map_err(|_| HandshakeError::NoWebsocketUpgrade)?; + log::trace!("Ws handshake verification completed for {:?}", req.path()); + + // create sink + let codec = ws::Codec::new(); + let sink = WsSink::new(io.get_ref(), codec.clone()); // create ws service - let srv = factory - .into_factory() - .new_service(sink.clone()) - .await? - .map_err(|e| { - let e: Box = Box::new(e); - e - }); + let srv = factory.into_factory().new_service(sink).await?; // start websockets service dispatcher - rt::spawn(crate::util::stream::Dispatcher::new( - // wrap bytes stream to ws::Frame's stream - MapStream { - stream: ws::StreamDecoder::new(payload), - _t: PhantomData, - }, - // converter wraper from ws::Message to Bytes - sink, - // websockets handler service - srv, - )); + rt::spawn(async move { + let res = crate::io::Dispatcher::new(io, codec, srv).await; + log::trace!("Ws handler is terminated: {:?}", res); + }); - Ok(res.body(Body::from_message(BoxedBodyStream::new(rx)))) -} - -pin_project_lite::pin_project! { - struct MapStream{ - #[pin] - stream: S, - _t: PhantomData<(I, E)>, - } -} - -impl Stream for MapStream -where - S: Stream>, - E: error::Error + 'static, -{ - type Item = Result>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project().stream.poll_next(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Some(Ok(item))) => Poll::Ready(Some(Ok(item))), - Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Box::new(err)))), - Poll::Ready(None) => Poll::Ready(None), - } - } - - fn size_hint(&self) -> (usize, Option) { - self.stream.size_hint() - } + Ok(HttpResponse::new(StatusCode::OK)) } diff --git a/ntex/src/ws/client.rs b/ntex/src/ws/client.rs index ad536cbd..eeaa56f4 100644 --- a/ntex/src/ws/client.rs +++ b/ntex/src/ws/client.rs @@ -17,7 +17,7 @@ use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1}; use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri}; use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Sealed}; use crate::service::{apply_fn, into_service, IntoService, Service}; -use crate::util::{sink, Either, Ready}; +use crate::util::{Either, Ready}; use crate::{channel::mpsc, rt, time::timeout, time::Millis, ws}; use super::error::{WsClientBuilderError, WsClientError, WsError}; @@ -695,29 +695,25 @@ impl WsConnection { impl WsConnection { // TODO: fix close frame handling /// Start client websockets with `SinkService` and `mpsc::Receiver` - pub fn start_default(self) -> mpsc::Receiver>> { + pub fn receiver(self) -> mpsc::Receiver>> { let (tx, rx): (_, mpsc::Receiver>>) = mpsc::channel(); rt::spawn(async move { + let tx2 = tx.clone(); let io = self.io.get_ref(); - let srv = sink::SinkService::new(tx.clone()).map(|_| None); - if let Err(err) = self - .start(into_service(move |item| { - let io = io.clone(); - let close = matches!(item, ws::Frame::Close(_)); - let fut = srv.call(Ok::<_, WsError<()>>(item)); - async move { - let result = fut.await.map_err(|_| ()); - if close { - io.close(); - } - result - } + let result = self + .start(into_service(move |item: ws::Frame| { + match tx.send(Ok(item)) { + Ok(()) => (), + Err(_) => io.close(), + }; + Ready::Ok::, ()>(None) })) - .await - { - let _ = tx.send(Err(err)); + .await; + + if let Err(e) = result { + let _ = tx2.send(Err(e)); } }); diff --git a/ntex/src/ws/mod.rs b/ntex/src/ws/mod.rs index 3eddb99d..4a395c2d 100644 --- a/ntex/src/ws/mod.rs +++ b/ntex/src/ws/mod.rs @@ -10,7 +10,6 @@ mod handshake; mod mask; mod proto; mod sink; -mod stream; mod transport; pub mod error; @@ -21,5 +20,4 @@ pub use self::frame::Parser; pub use self::handshake::{handshake, handshake_response, verify_handshake}; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; pub use self::sink::WsSink; -pub use self::stream::{StreamDecoder, StreamEncoder}; pub use self::transport::{WsTransport, WsTransportFactory}; diff --git a/ntex/src/ws/sink.rs b/ntex/src/ws/sink.rs index b3b8a248..7744662b 100644 --- a/ntex/src/ws/sink.rs +++ b/ntex/src/ws/sink.rs @@ -3,6 +3,7 @@ use std::{future::Future, rc::Rc}; use crate::io::{IoRef, OnDisconnect}; use crate::ws; +#[derive(Clone)] pub struct WsSink(Rc); struct WsSinkInner { @@ -15,6 +16,11 @@ impl WsSink { Self(Rc::new(WsSinkInner { io, codec })) } + /// Io reference + pub fn io(&self) -> &IoRef { + &self.0.io + } + /// Endcode and send message to the peer. pub fn send( &self, diff --git a/ntex/src/ws/stream.rs b/ntex/src/ws/stream.rs deleted file mode 100644 index 959a583f..00000000 --- a/ntex/src/ws/stream.rs +++ /dev/null @@ -1,225 +0,0 @@ -use std::{ - cell::RefCell, fmt, marker::PhantomData, pin::Pin, rc::Rc, task::Context, task::Poll, -}; - -use super::{error::ProtocolError, Codec, Frame, Message}; -use crate::util::{Bytes, BytesMut, Sink, Stream}; -use crate::{codec::Decoder, codec::Encoder}; - -/// Stream error -#[derive(Debug, Display)] -pub enum StreamError { - #[display(fmt = "StreamError::Stream({:?})", _0)] - Stream(E), - Protocol(ProtocolError), -} - -impl std::error::Error for StreamError {} - -impl From for StreamError { - fn from(err: ProtocolError) -> Self { - StreamError::Protocol(err) - } -} - -pin_project_lite::pin_project! { - /// Stream ws protocol decoder. - pub struct StreamDecoder { - #[pin] - stream: S, - codec: Codec, - buf: BytesMut, - _t: PhantomData, - } -} - -impl StreamDecoder { - pub fn new(stream: S) -> Self { - StreamDecoder::with(stream, Codec::new()) - } - - pub fn with(stream: S, codec: Codec) -> Self { - StreamDecoder { - stream, - codec, - buf: BytesMut::new(), - _t: PhantomData, - } - } -} - -impl Stream for StreamDecoder -where - S: Stream>, - E: fmt::Debug, -{ - type Item = Result>; - - #[inline] - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let mut this = self.as_mut().project(); - - loop { - if !this.buf.is_empty() { - match this.codec.decode(this.buf) { - Ok(Some(item)) => return Poll::Ready(Some(Ok(item))), - Ok(None) => (), - Err(err) => return Poll::Ready(Some(Err(err.into()))), - } - } - - match this.stream.poll_next(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Some(Ok(buf))) => { - this.buf.extend(&buf); - this = self.as_mut().project(); - } - Poll::Ready(Some(Err(err))) => { - return Poll::Ready(Some(Err(StreamError::Stream(err)))) - } - Poll::Ready(None) => return Poll::Ready(None), - } - } - } -} - -pin_project_lite::pin_project! { - /// Stream ws protocol decoder. - #[derive(Clone)] - pub struct StreamEncoder { - #[pin] - sink: S, - codec: Rc>, - } -} - -impl StreamEncoder { - pub fn new(sink: S) -> Self { - StreamEncoder::with(sink, Codec::new()) - } - - pub fn with(sink: S, codec: Codec) -> Self { - StreamEncoder { - sink, - codec: Rc::new(RefCell::new(codec)), - } - } -} - -impl Sink> for StreamEncoder -where - S: Sink>, - S::Error: fmt::Debug, -{ - type Error = StreamError; - - #[inline] - fn poll_ready( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.project() - .sink - .poll_ready(cx) - .map_err(StreamError::Stream) - } - - fn start_send( - self: Pin<&mut Self>, - item: Result, - ) -> Result<(), Self::Error> { - let this = self.project(); - - match item { - Ok(item) => { - let mut buf = BytesMut::new(); - this.codec.borrow_mut().encode(item, &mut buf)?; - this.sink - .start_send(Ok(buf.freeze())) - .map_err(StreamError::Stream) - } - Err(e) => this.sink.start_send(Err(e)).map_err(StreamError::Stream), - } - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.project() - .sink - .poll_flush(cx) - .map_err(StreamError::Stream) - } - - fn poll_close( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - self.project() - .sink - .poll_close(cx) - .map_err(StreamError::Stream) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - channel::mpsc, util::poll_fn, util::send, util::stream_recv, util::ByteString, - }; - - #[crate::rt_test] - async fn test_decoder() { - let (tx, rx) = mpsc::channel(); - let mut decoder = StreamDecoder::new(rx); - - let mut buf = BytesMut::new(); - let codec = Codec::new().client_mode(); - codec - .encode(Message::Text(ByteString::from_static("test1")), &mut buf) - .unwrap(); - codec - .encode(Message::Text(ByteString::from_static("test2")), &mut buf) - .unwrap(); - - tx.send(Ok::<_, ()>(buf.split().freeze())).unwrap(); - let frame = stream_recv(&mut decoder).await.unwrap().unwrap(); - match frame { - Frame::Text(data) => assert_eq!(data, b"test1"[..]), - _ => panic!(), - } - let frame = stream_recv(&mut decoder).await.unwrap().unwrap(); - match frame { - Frame::Text(data) => assert_eq!(data, b"test2"[..]), - _ => panic!(), - } - } - - #[crate::rt_test] - async fn test_encoder() { - let (tx, mut rx) = mpsc::channel(); - let mut encoder = StreamEncoder::new(tx); - - send( - &mut encoder, - Ok::<_, ()>(Message::Text(ByteString::from_static("test"))), - ) - .await - .unwrap(); - poll_fn(|cx| Pin::new(&mut encoder).poll_flush(cx)) - .await - .unwrap(); - poll_fn(|cx| Pin::new(&mut encoder).poll_close(cx)) - .await - .unwrap(); - - let data = stream_recv(&mut rx).await.unwrap().unwrap(); - assert_eq!(data, b"\x81\x04test".as_ref()); - assert!(stream_recv(&mut rx).await.is_none()); - } -} diff --git a/ntex/tests/web_ws.rs b/ntex/tests/web_ws.rs index 6c926374..9fd37ab1 100644 --- a/ntex/tests/web_ws.rs +++ b/ntex/tests/web_ws.rs @@ -1,6 +1,5 @@ use std::io; -use futures_util::StreamExt; use ntex::http::StatusCode; use ntex::service::{fn_factory_with_config, fn_service}; use ntex::util::{ByteString, Bytes}; @@ -23,10 +22,9 @@ async fn service(msg: ws::Frame) -> Result, io::Error> { async fn web_ws() { let srv = test::server(|| { App::new().service(web::resource("/").route(web::to( - |req: HttpRequest, pl: web::types::Payload| async move { - ws::start::<_, _, _, web::Error>( + |req: HttpRequest| async move { + ws::start::<_, _, web::Error>( req, - pl, fn_factory_with_config(|_| async { Ok::<_, web::Error>(fn_service(service)) }), @@ -71,10 +69,9 @@ async fn web_ws() { async fn web_ws_client() { let srv = test::server(|| { App::new().service(web::resource("/").route(web::to( - |req: HttpRequest, pl: web::types::Payload| async move { - ws::start::<_, _, _, web::Error>( + |req: HttpRequest| async move { + ws::start::<_, _, web::Error>( req, - pl, fn_factory_with_config(|_| async { Ok::<_, web::Error>(fn_service(service)) }), @@ -89,33 +86,33 @@ async fn web_ws_client() { assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS); let sink = conn.sink(); - let mut rx = conn.start_default(); + let rx = conn.receiver(); sink.send(ws::Message::Text(ByteString::from_static("text"))) .await .unwrap(); - let item = rx.next().await.unwrap().unwrap(); + let item = rx.recv().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); sink.send(ws::Message::Binary("text".into())).await.unwrap(); - let item = rx.next().await.unwrap().unwrap(); + let item = rx.recv().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); sink.send(ws::Message::Ping("text".into())).await.unwrap(); - let item = rx.next().await.unwrap().unwrap(); + let item = rx.recv().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Pong("text".to_string().into())); - let _on_disconnect = sink.on_disconnect(); + let on_disconnect = sink.on_disconnect(); sink.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))) .await .unwrap(); - let item = rx.next().await.unwrap().unwrap(); + let item = rx.recv().await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Away.into()))); - let item = rx.next().await; + let item = rx.recv().await; assert!(item.is_none()); // TODO fix - // on_disconnect.await + on_disconnect.await }