diff --git a/ntex-service/src/lib.rs b/ntex-service/src/lib.rs index 0e7d5e6e..2e07a8be 100644 --- a/ntex-service/src/lib.rs +++ b/ntex-service/src/lib.rs @@ -3,7 +3,6 @@ use std::future::Future; use std::rc::Rc; -use std::sync::Arc; use std::task::{self, Context, Poll}; mod and_then; @@ -190,6 +189,11 @@ pub trait ServiceFactory { type Future: Future>; /// Create and return a new service value asynchronously. + fn create(&self, cfg: Self::Config) -> Self::Future { + self.new_service(cfg) + } + + #[doc(hidden)] fn new_service(&self, cfg: Self::Config) -> Self::Future; #[inline] @@ -224,7 +228,7 @@ pub trait ServiceFactory { } } -impl<'a, S> Service for &'a mut S +impl<'a, S> Service for &'a S where S: Service + 'a, { @@ -314,23 +318,6 @@ where } } -impl ServiceFactory for Arc -where - S: ServiceFactory, -{ - type Request = S::Request; - type Response = S::Response; - type Error = S::Error; - type Config = S::Config; - type Service = S::Service; - type InitError = S::InitError; - type Future = S::Future; - - fn new_service(&self, cfg: S::Config) -> S::Future { - self.as_ref().new_service(cfg) - } -} - /// Trait for types that can be converted to a `Service` pub trait IntoService where diff --git a/ntex/src/channel/mpsc.rs b/ntex/src/channel/mpsc.rs index 7175645d..461514f5 100644 --- a/ntex/src/channel/mpsc.rs +++ b/ntex/src/channel/mpsc.rs @@ -58,7 +58,6 @@ impl Sender { /// This prevents any further messages from being sent on the channel while /// still enabling the receiver to drain messages that are buffered. pub fn close(&self) { - println!("Close mpsc"); let shared = self.shared.get_mut(); shared.has_receiver = false; shared.blocked_recv.wake(); @@ -98,6 +97,7 @@ impl Sink for Sender { self: Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll> { + self.close(); Poll::Ready(Ok(())) } } @@ -195,11 +195,14 @@ impl SendError { mod tests { use super::*; use futures::future::lazy; - use futures::{Stream, StreamExt}; + use futures::{Sink, Stream, StreamExt}; #[ntex_rt::test] async fn test_mpsc() { let (tx, mut rx) = channel(); + assert!(format!("{:?}", tx).contains("Sender")); + assert!(format!("{:?}", rx).contains("Receiver")); + tx.send("test").unwrap(); assert_eq!(rx.next().await.unwrap(), "test"); @@ -238,4 +241,18 @@ mod tests { assert!(format!("{}", err).contains("send failed because receiver is gone")); assert_eq!(err.into_inner(), "test"); } + + #[ntex_rt::test] + async fn test_sink() { + let (mut tx, mut rx) = channel(); + lazy(|cx| { + assert!(Pin::new(&mut tx).poll_ready(cx).is_ready()); + assert!(Pin::new(&mut tx).start_send("test").is_ok()); + assert!(Pin::new(&mut tx).poll_flush(cx).is_ready()); + assert!(Pin::new(&mut tx).poll_close(cx).is_ready()); + }) + .await; + assert_eq!(rx.next().await.unwrap(), "test"); + assert_eq!(rx.next().await, None); + } } diff --git a/ntex/src/channel/oneshot.rs b/ntex/src/channel/oneshot.rs index a02ced82..5e0e7e04 100644 --- a/ntex/src/channel/oneshot.rs +++ b/ntex/src/channel/oneshot.rs @@ -284,11 +284,13 @@ mod tests { #[ntex_rt::test] async fn test_pool() { - let (tx, rx) = pool().channel(); + let p = pool(); + let (tx, rx) = p.channel(); tx.send("test").unwrap(); assert_eq!(rx.await.unwrap(), "test"); - let (tx, rx) = pool().channel(); + let p2 = p.clone(); + let (tx, rx) = p2.channel(); assert!(!tx.is_canceled()); drop(rx); assert!(tx.is_canceled()); diff --git a/ntex/src/framed/service.rs b/ntex/src/framed/service.rs index edb02ce7..a17acd6c 100644 --- a/ntex/src/framed/service.rs +++ b/ntex/src/framed/service.rs @@ -462,6 +462,9 @@ where > { #[project] match self.project() { + FramedServiceImplResponseInner::Dispatcher(ref mut fut) => { + Either::Right(fut.poll_inner(cx)) + } FramedServiceImplResponseInner::Handshake(fut, handler, timeout) => { match fut.poll(cx) { Poll::Ready(Ok(res)) => { @@ -499,9 +502,6 @@ where Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))), } } - FramedServiceImplResponseInner::Dispatcher(ref mut fut) => { - Either::Right(fut.poll_inner(cx)) - } } } } diff --git a/ntex/src/testing.rs b/ntex/src/testing.rs index 6168e117..ebab9097 100644 --- a/ntex/src/testing.rs +++ b/ntex/src/testing.rs @@ -163,7 +163,10 @@ impl Io { /// Read any available data pub fn remote_buffer_cap(&self, cap: usize) { + // change cap self.local.lock().unwrap().borrow_mut().buf_cap = cap; + // wake remote + self.remote.lock().unwrap().borrow().waker.wake(); } /// Read any available data @@ -268,7 +271,7 @@ impl AsyncRead for Io { impl AsyncWrite for Io { fn poll_write( self: Pin<&mut Self>, - _: &mut Context<'_>, + cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let guard = self.remote.lock().unwrap(); @@ -284,11 +287,25 @@ impl AsyncWrite for Io { ch.waker.wake(); Poll::Ready(Ok(cap)) } else { + self.local + .lock() + .unwrap() + .borrow_mut() + .waker + .register(cx.waker()); Poll::Pending } } IoState::Close => Poll::Ready(Ok(0)), - IoState::Pending => Poll::Pending, + IoState::Pending => { + self.local + .lock() + .unwrap() + .borrow_mut() + .waker + .register(cx.waker()); + Poll::Pending + } IoState::Err(e) => Poll::Ready(Err(e)), } } diff --git a/ntex/src/util/framed.rs b/ntex/src/util/framed.rs index a50b7450..96b33ada 100644 --- a/ntex/src/util/framed.rs +++ b/ntex/src/util/framed.rs @@ -363,9 +363,7 @@ where Poll::Ready(Err(err)) => { debug!("Error sending data: {:?}", err); } - Poll::Pending => { - return Poll::Pending; - } + Poll::Pending => return Poll::Pending, Poll::Ready(_) => (), } }; @@ -459,7 +457,10 @@ mod tests { let framed = Framed::new(server, BytesCodec); let disp = Dispatcher::new( framed, - crate::fn_service(|msg: BytesMut| ok::<_, ()>(Some(msg.freeze()))), + crate::fn_service(|msg: BytesMut| async move { + delay_for(Duration::from_millis(50)).await; + Ok::<_, ()>(Some(msg.freeze())) + }), ); crate::rt::spawn(disp.map(|_| ())); @@ -497,4 +498,36 @@ mod tests { delay_for(Duration::from_millis(100)).await; assert!(client.is_server_dropped()); } + + #[ntex_rt::test] + async fn test_err_in_service() { + let (client, server) = Io::create(); + client.remote_buffer_cap(0); + client.write("GET /test HTTP/1\r\n\r\n"); + + let mut framed = Framed::new(server, BytesCodec); + framed.write_buf().extend(b"GET /test HTTP/1\r\n\r\n"); + + let disp = Dispatcher::new( + framed, + crate::fn_service(|_: BytesMut| async { Err::, _>(()) }), + ); + crate::rt::spawn(disp.map(|_| ())); + + let buf = client.read_any(); + assert_eq!(buf, Bytes::from_static(b"")); + delay_for(Duration::from_millis(25)).await; + + // buffer should be flushed + client.remote_buffer_cap(1024); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); + + // write side must be closed, dispatcher waiting for read side to close + assert!(client.is_closed()); + + // close read side + client.close().await; + assert!(client.is_server_dropped()); + } } diff --git a/ntex/src/util/order.rs b/ntex/src/util/order.rs index eb4718aa..2542f34d 100644 --- a/ntex/src/util/order.rs +++ b/ntex/src/util/order.rs @@ -3,7 +3,6 @@ use std::collections::VecDeque; use std::convert::Infallible; use std::fmt; use std::future::Future; -use std::marker::PhantomData; use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll}; @@ -53,39 +52,26 @@ impl fmt::Display for InOrderError { /// InOrder - The service will yield responses as they become available, /// in the order that their originating requests were submitted to the service. -pub struct InOrder { - _t: PhantomData, -} +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] +pub struct InOrder; -impl InOrder -where - S: Service + 'static, - S::Response: 'static, - S::Future: 'static, - S::Error: 'static, -{ +impl InOrder { pub fn new() -> Self { - Self { _t: PhantomData } + Self } - pub fn service(service: S) -> InOrderService { + pub fn service(service: S) -> InOrderService + where + S: Service + 'static, + S::Response: 'static, + S::Future: 'static, + S::Error: 'static, + { InOrderService::new(service) } } -impl Default for InOrder -where - S: Service + 'static, - S::Response: 'static, - S::Future: 'static, - S::Error: 'static, -{ - fn default() -> Self { - Self::new() - } -} - -impl Transform for InOrder +impl Transform for InOrder where S: Service + 'static, S::Response: 'static, @@ -169,13 +155,7 @@ where } // check nested service - if let Poll::Pending = - self.service.poll_ready(cx).map_err(InOrderError::Service)? - { - Poll::Pending - } else { - Poll::Ready(Ok(())) - } + self.service.poll_ready(cx).map_err(InOrderError::Service) } #[inline] @@ -262,7 +242,7 @@ mod tests { let rx3 = rx3; let tx_stop = tx_stop; let _ = crate::rt::System::new("test").block_on(async { - let srv = InOrderService::new(Srv); + let srv = InOrder::default().new_transform(Srv).await.unwrap(); let _ = lazy(|cx| srv.poll_ready(cx)).await; let res1 = srv.call(rx1); @@ -294,4 +274,18 @@ mod tests { let _ = rx_stop.await; let _ = h.join(); } + + #[test] + fn test_error() { + #[derive(Debug, derive_more::Display)] + struct TestError; + + let e = InOrderError::::Disconnected; + assert!(format!("{:?}", e).contains("InOrderError::Disconnected")); + assert!(format!("{}", e).contains("InOrder service disconnected")); + + let e: InOrderError = TestError.into(); + assert!(format!("{:?}", e).contains("InOrderError::Service(TestError)")); + assert!(format!("{}", e).contains("TestError")); + } } diff --git a/ntex/src/util/stream.rs b/ntex/src/util/stream.rs index 8be5e8ca..931db734 100644 --- a/ntex/src/util/stream.rs +++ b/ntex/src/util/stream.rs @@ -17,7 +17,6 @@ where stream: S, service: T, err_rx: mpsc::Receiver, - err_tx: mpsc::Sender, } impl Dispatcher @@ -29,11 +28,9 @@ where where F: IntoService, { - let (err_tx, err_rx) = mpsc::channel(); Dispatcher { - err_rx, - err_tx, stream, + err_rx: mpsc::channel().1, service: service.into_service(), } } @@ -57,7 +54,7 @@ where return match this.service.poll_ready(cx)? { Poll::Ready(_) => match this.stream.poll_next(cx) { Poll::Ready(Some(item)) => { - let stop = this.err_tx.clone(); + let stop = this.err_rx.sender(); crate::rt::spawn(this.service.call(item).map(move |res| { if let Err(e) = res { let _ = stop.send(e); @@ -74,3 +71,37 @@ where } } } + +#[cfg(test)] +mod tests { + use futures::future::ok; + use std::cell::Cell; + use std::rc::Rc; + use std::time::Duration; + + use super::*; + use crate::channel::mpsc; + use crate::rt::time::delay_for; + + #[ntex_rt::test] + async fn test_basic() { + let (tx, rx) = mpsc::channel(); + let counter = Rc::new(Cell::new(0)); + let counter2 = counter.clone(); + + let disp = Dispatcher::new( + rx, + crate::fn_service(move |_: ()| { + counter2.set(counter2.get() + 1); + ok::<_, ()>(()) + }), + ); + crate::rt::spawn(disp.map(|_| ())); + + tx.send(()).unwrap(); + tx.send(()).unwrap(); + drop(tx); + delay_for(Duration::from_millis(10)).await; + assert_eq!(counter.get(), 2); + } +} diff --git a/ntex/src/util/time.rs b/ntex/src/util/time.rs index 6e4193d7..5a39b24b 100644 --- a/ntex/src/util/time.rs +++ b/ntex/src/util/time.rs @@ -160,8 +160,17 @@ impl SystemTimeService { #[cfg(test)] mod tests { use super::*; + use futures::future::lazy; use std::time::{Duration, SystemTime}; + #[ntex_rt::test] + async fn low_res_timee() { + let f = LowResTime::default(); + let srv = f.new_service(()).await.unwrap(); + assert!(lazy(|cx| srv.poll_ready(cx)).await.is_ready()); + srv.call(()).await.unwrap(); + } + /// State Under Test: Two calls of `SystemTimeService::now()` return the same value if they are done within resolution interval of `SystemTimeService`. /// /// Expected Behavior: Two back-to-back calls of `SystemTimeService::now()` return the same value. diff --git a/ntex/src/web/error.rs b/ntex/src/web/error.rs index cc37f9dc..e0434547 100644 --- a/ntex/src/web/error.rs +++ b/ntex/src/web/error.rs @@ -680,6 +680,7 @@ mod tests { use std::io; use super::*; + use crate::http::client::error::{ConnectError, SendRequestError}; use crate::web::test::TestRequest; use crate::web::DefaultError; @@ -707,26 +708,25 @@ mod tests { let req = TestRequest::default().to_http_request(); use crate::util::timeout::TimeoutError; - let resp: HttpResponse = WebResponseError::::error_response( + let resp = WebResponseError::::error_response( &TimeoutError::::Timeout, &req, ); assert_eq!(resp.status(), StatusCode::GATEWAY_TIMEOUT); - use crate::http::client::error::{ConnectError, SendRequestError}; - let resp: HttpResponse = WebResponseError::::error_response( + let resp = WebResponseError::::error_response( &SendRequestError::Connect(ConnectError::Timeout), &req, ); assert_eq!(resp.status(), StatusCode::GATEWAY_TIMEOUT); - let resp: HttpResponse = WebResponseError::::error_response( + let resp = WebResponseError::::error_response( &SendRequestError::Connect(ConnectError::SslIsNotSupported), &req, ); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = WebResponseError::::error_response( + let resp = WebResponseError::::error_response( &SendRequestError::TunnelNotSupported, &req, ); @@ -741,11 +741,38 @@ mod tests { assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } - let resp: HttpResponse = WebResponseError::::error_response( + let resp = WebResponseError::::error_response( &crate::http::error::ContentTypeError::ParseError, &req, ); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let err = serde_urlencoded::from_str::("bad query").unwrap_err(); + let resp = WebResponseError::::error_response(&err, &req); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + + let err = PayloadError::Decoding; + let resp = WebResponseError::::error_response(&err, &req); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_either_error() { + let req = TestRequest::default().to_http_request(); + + let err: either::Either = + either::Either::Left(SendRequestError::TunnelNotSupported); + let code = WebResponseError::::status_code(&err); + assert_eq!(code, StatusCode::INTERNAL_SERVER_ERROR); + let resp = WebResponseError::::error_response(&err, &req); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + + let err: either::Either = + either::Either::Right(PayloadError::Decoding); + let code = WebResponseError::::status_code(&err); + assert_eq!(code, StatusCode::BAD_REQUEST); + let resp = WebResponseError::::error_response(&err, &req); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } #[test] @@ -842,6 +869,14 @@ mod tests { #[test] fn test_error_helpers() { + let err = ErrorBadRequest::<_, DefaultError>("err"); + assert!(format!("{:?}", err).contains("web::InternalError")); + + let err: InternalError<_, DefaultError> = + InternalError::from_response("err", HttpResponse::BadRequest().finish()); + let r: HttpResponse = err.into(); + assert_eq!(r.status(), StatusCode::BAD_REQUEST); + let r: HttpResponse = ErrorBadRequest::<_, DefaultError>("err").into(); assert_eq!(r.status(), StatusCode::BAD_REQUEST); diff --git a/ntex/src/web/httprequest.rs b/ntex/src/web/httprequest.rs index 9907544e..66d50b93 100644 --- a/ntex/src/web/httprequest.rs +++ b/ntex/src/web/httprequest.rs @@ -351,6 +351,12 @@ mod tests { TestRequest::with_header("content-type", "text/plain").to_http_request(); let dbg = format!("{:?}", req); assert!(dbg.contains("HttpRequest")); + + let req = TestRequest::with_uri("/index.html?q=?").to_http_request(); + let dbg = format!("{:?}", req); + assert!(dbg.contains("HttpRequest")); + assert!(req.peer_addr().is_none()); + assert_eq!(req.method(), &Method::GET); } #[cfg(feature = "cookie")] diff --git a/ntex/src/web/response.rs b/ntex/src/web/response.rs index 39180091..08eb043a 100644 --- a/ntex/src/web/response.rs +++ b/ntex/src/web/response.rs @@ -90,18 +90,17 @@ impl WebResponse { } /// Execute closure and in case of error convert it to response. - pub fn checked_expr(mut self, f: F) -> Self + pub fn checked_expr(mut self, f: F) -> Self where F: FnOnce(&mut Self) -> Result<(), E>, E: Into, Err: ErrorRenderer, { - match f(&mut self) { - Ok(_) => self, - Err(err) => { - let res: Response = err.into().into(); - WebResponse::new(res, self.request) - } + if let Err(err) = f(&mut self) { + let res: Response = err.into().into(); + WebResponse::new(res, self.request) + } else { + self } } @@ -149,3 +148,31 @@ impl fmt::Debug for WebResponse { res } } + +#[cfg(test)] +mod tests { + use crate::http::{self, StatusCode}; + use crate::web::test::TestRequest; + use crate::web::{DefaultError, HttpResponse}; + + #[test] + fn test_response() { + let res = TestRequest::default().to_srv_response(HttpResponse::Ok().finish()); + let res = res.into_response(HttpResponse::BadRequest().finish()); + assert_eq!(res.response().status(), StatusCode::BAD_REQUEST); + + let err = http::error::PayloadError::Overflow; + let res = res.error_response::(err); + assert_eq!(res.response().status(), StatusCode::PAYLOAD_TOO_LARGE); + + let res = TestRequest::default().to_srv_response(HttpResponse::Ok().finish()); + let mut res = res.checked_expr::(|_| { + Ok::<_, http::error::PayloadError>(()) + }); + assert_eq!(res.response_mut().status(), StatusCode::OK); + let res = res.checked_expr::(|_| { + Err(http::error::PayloadError::Overflow) + }); + assert_eq!(res.response().status(), StatusCode::PAYLOAD_TOO_LARGE); + } +} diff --git a/ntex/src/web/types/form.rs b/ntex/src/web/types/form.rs index eee91a39..f946fc0d 100644 --- a/ntex/src/web/types/form.rs +++ b/ntex/src/web/types/form.rs @@ -130,7 +130,7 @@ where impl fmt::Debug for Form { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) + f.debug_tuple("Form").field(&self.0).finish() } } @@ -359,12 +359,26 @@ mod tests { use crate::http::header::{HeaderValue, CONTENT_TYPE}; use crate::web::test::{from_request, respond_to, TestRequest}; - #[derive(Deserialize, Serialize, Debug, PartialEq)] + #[derive(Deserialize, Serialize, Debug, PartialEq, derive_more::Display)] + #[display(fmt = "{}", "hello")] struct Info { hello: String, counter: i64, } + #[test] + fn test_basic() { + let mut f = Form(Info { + hello: "world".into(), + counter: 123, + }); + assert_eq!(f.hello, "world"); + f.hello = "test".to_string(); + assert_eq!(f.hello, "test"); + assert!(format!("{:?}", f).contains("Form")); + assert!(format!("{}", f).contains("test")); + } + #[ntex_rt::test] async fn test_form() { let (req, mut pl) = diff --git a/ntex/src/web/types/json.rs b/ntex/src/web/types/json.rs index 820a4a4a..a98de727 100644 --- a/ntex/src/web/types/json.rs +++ b/ntex/src/web/types/json.rs @@ -104,7 +104,7 @@ where T: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Json: {:?}", self.0) + f.debug_tuple("Json").field(&self.0).finish() } } @@ -393,7 +393,7 @@ mod tests { use crate::http::header; use crate::web::test::{from_request, respond_to, TestRequest}; - #[derive(Serialize, Deserialize, PartialEq, Debug)] + #[derive(Serialize, Deserialize, PartialEq, Debug, derive_more::Display)] struct MyObject { name: String, } @@ -412,6 +412,18 @@ mod tests { } } + #[test] + fn test_json() { + let mut j = Json(MyObject { + name: "test2".to_string(), + }); + assert_eq!(j.name, "test2"); + j.name = "test".to_string(); + assert_eq!(j.name, "test"); + assert!(format!("{:?}", j).contains("Json")); + assert!(format!("{}", j).contains("test")); + } + #[ntex_rt::test] async fn test_responder() { let req = TestRequest::default().to_http_request();