diff --git a/ntex-codec/CHANGES.md b/ntex-codec/CHANGES.md index cd65e31a..95c2085f 100644 --- a/ntex-codec/CHANGES.md +++ b/ntex-codec/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.3.0] - 2021-01-23 + +* Make Encoder and Decoder methods immutable + ## [0.2.2] - 2021-01-21 * Flush underlying io stream diff --git a/ntex-codec/Cargo.toml b/ntex-codec/Cargo.toml index 82a2747c..7f0263d3 100644 --- a/ntex-codec/Cargo.toml +++ b/ntex-codec/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-codec" -version = "0.2.2" +version = "0.3.0-b.1" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -18,12 +18,12 @@ path = "src/lib.rs" [dependencies] bitflags = "1.2.1" bytes = "0.5.6" -either = "1.5.3" +either = "1.6.1" futures-core = "0.3.12" futures-sink = "0.3.12" log = "0.4" tokio = { version = "0.2.6", default-features=false } [dev-dependencies] -ntex = "0.2.0-b.2" +ntex = "0.2.0-b.3" futures = "0.3.12" diff --git a/ntex-codec/src/bcodec.rs b/ntex-codec/src/bcodec.rs index fc2ef292..604a1878 100644 --- a/ntex-codec/src/bcodec.rs +++ b/ntex-codec/src/bcodec.rs @@ -14,7 +14,7 @@ impl Encoder for BytesCodec { type Error = io::Error; #[inline] - fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> { + fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> { dst.extend_from_slice(item.bytes()); Ok(()) } @@ -24,7 +24,7 @@ impl Decoder for BytesCodec { type Item = BytesMut; type Error = io::Error; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { if src.is_empty() { Ok(None) } else { diff --git a/ntex-codec/src/decoder.rs b/ntex-codec/src/decoder.rs index e5cdcae8..a5dfec3e 100644 --- a/ntex-codec/src/decoder.rs +++ b/ntex-codec/src/decoder.rs @@ -1,4 +1,5 @@ use bytes::BytesMut; +use std::rc::Rc; /// Decoding of frames via buffers. pub trait Decoder { @@ -13,7 +14,7 @@ pub trait Decoder { type Error: std::fmt::Debug; /// Attempts to decode a frame from the provided buffer of bytes. - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error>; + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error>; /// A default method available to be called when there are no more bytes /// available to be read from the underlying I/O. @@ -22,13 +23,26 @@ pub trait Decoder { /// `Ok(None)` is returned while there is unconsumed data in `buf`. /// Typically this doesn't need to be implemented unless the framing /// protocol differs near the end of the stream. - fn decode_eof( - &mut self, - buf: &mut BytesMut, - ) -> Result, Self::Error> { + fn decode_eof(&self, buf: &mut BytesMut) -> Result, Self::Error> { match self.decode(buf)? { Some(frame) => Ok(Some(frame)), None => Ok(None), } } } + +impl Decoder for Rc +where + T: Decoder, +{ + type Item = T::Item; + type Error = T::Error; + + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { + (**self).decode(src) + } + + fn decode_eof(&self, src: &mut BytesMut) -> Result, Self::Error> { + (**self).decode_eof(src) + } +} diff --git a/ntex-codec/src/encoder.rs b/ntex-codec/src/encoder.rs index 85ccef35..cd8166ce 100644 --- a/ntex-codec/src/encoder.rs +++ b/ntex-codec/src/encoder.rs @@ -1,4 +1,5 @@ use bytes::BytesMut; +use std::rc::Rc; /// Trait of helper objects to write out messages as bytes. pub trait Encoder { @@ -9,9 +10,17 @@ pub trait Encoder { type Error: std::fmt::Debug; /// Encodes a frame into the buffer provided. - fn encode( - &mut self, - item: Self::Item, - dst: &mut BytesMut, - ) -> Result<(), Self::Error>; + fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error>; +} + +impl Encoder for Rc +where + T: Encoder, +{ + type Item = T::Item; + type Error = T::Error; + + fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + (**self).encode(item, dst) + } } diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index 968a5d64..0e7fd3fd 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,15 @@ # Changes +## [0.2.0-b.4] - 2021-01-xx + +* http: Refactor h1 dispatcher + +* http: Remove generic type from `Request` + +* http: Remove generic type from `Payload` + +* Rename FrameReadTask/FramedWriteTask to ReadTask/WriteTask + ## [0.2.0-b.3] - 2021-01-21 * Allow to use framed write task for io flushing diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 3a439f91..4176cdbd 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.2.0-b.3" +version = "0.2.0-b.4" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -36,7 +36,7 @@ compress = ["flate2", "brotli2"] cookie = ["coo-kie", "coo-kie/percent-encode"] [dependencies] -ntex-codec = "0.2.2" +ntex-codec = "0.3.0-b.1" ntex-rt = "0.1.1" ntex-rt-macros = "0.1" ntex-router = "0.3.8" @@ -49,7 +49,7 @@ bitflags = "1.2.1" bytes = "0.5.6" bytestring = "0.1.5" derive_more = "0.99.11" -either = "1.5.3" +either = "1.6.1" encoding_rs = "0.8.26" futures = "0.3.12" ahash = "0.6.3" @@ -104,3 +104,6 @@ serde_derive = "1.0" open-ssl = { version="0.10", package = "openssl" } rust-tls = { version = "0.19.0", package="rustls", features = ["dangerous_configuration"] } webpki = "0.21.2" + +[patch.crates-io] +ntex = { path = "../ntex-codec" } diff --git a/ntex/src/framed/dispatcher.rs b/ntex/src/framed/dispatcher.rs index 199728d5..fe14d54b 100644 --- a/ntex/src/framed/dispatcher.rs +++ b/ntex/src/framed/dispatcher.rs @@ -9,7 +9,7 @@ use either::Either; use futures::FutureExt; use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; -use crate::framed::{DispatcherItem, FramedReadTask, FramedWriteTask, State, Timer}; +use crate::framed::{DispatchItem, ReadTask, State, Timer, WriteTask}; use crate::service::{IntoService, Service}; type Response = ::Item; @@ -19,7 +19,7 @@ pin_project_lite::pin_project! { /// and pass then to the service. pub struct Dispatcher where - S: Service, Response = Option>>, + S: Service, Response = Option>>, S::Error: 'static, S::Future: 'static, U: Encoder, @@ -27,12 +27,7 @@ pin_project_lite::pin_project! { ::Item: 'static, { service: S, - state: State, - inner: Rc>, - st: DispatcherState, - timer: Timer, - updated: Instant, - keepalive_timeout: u16, + inner: DispatcherInner, #[pin] response: Option, } @@ -40,12 +35,23 @@ pin_project_lite::pin_project! { struct DispatcherInner where - S: Service, Response = Option>>, - S::Error: 'static, - S::Future: 'static, + S: Service, Response = Option>>, U: Encoder + Decoder, - ::Item: 'static, { + st: DispatcherState, + state: State, + timer: Timer, + updated: Instant, + keepalive_timeout: u16, + shared: Rc>, +} + +struct DispatcherShared +where + S: Service, Response = Option>>, + U: Encoder + Decoder, +{ + codec: U, error: Cell::Error>>>, inflight: Cell, } @@ -73,13 +79,13 @@ impl From> for DispatcherError { } impl DispatcherError { - fn convert(self) -> Option> + fn convert(self) -> Option> where U: Encoder + Decoder, { match self { - DispatcherError::KeepAlive => Some(DispatcherItem::KeepAliveTimeout), - DispatcherError::Encoder(err) => Some(DispatcherItem::EncoderError(err)), + DispatcherError::KeepAlive => Some(DispatchItem::KeepAliveTimeout), + DispatcherError::Encoder(err) => Some(DispatchItem::EncoderError(err)), DispatcherError::Service(_) => None, } } @@ -87,46 +93,59 @@ impl DispatcherError { impl Dispatcher where - S: Service, Response = Option>> + 'static, + S: Service, Response = Option>> + 'static, U: Decoder + Encoder + 'static, ::Item: 'static, { /// Construct new `Dispatcher` instance. pub fn new>( io: T, - state: State, + codec: U, + state: State, service: F, - timer: Timer, + timer: Timer, ) -> Self where T: AsyncRead + AsyncWrite + Unpin + 'static, { + let io = Rc::new(RefCell::new(io)); + + // start support tasks + crate::rt::spawn(ReadTask::new(io.clone(), state.clone())); + crate::rt::spawn(WriteTask::new(io, state.clone())); + + Self::from_state(codec, state, service, timer) + } + + /// Construct new `Dispatcher` instance. + pub fn from_state>( + codec: U, + state: State, + service: F, + timer: Timer, + ) -> Self { let updated = timer.now(); let keepalive_timeout: u16 = 30; - let io = Rc::new(RefCell::new(io)); // register keepalive timer let expire = updated + Duration::from_secs(keepalive_timeout as u64); timer.register(expire, expire, &state); - // start support tasks - crate::rt::spawn(FramedReadTask::new(io.clone(), state.clone())); - crate::rt::spawn(FramedWriteTask::new(io, state.clone())); - - let inner = Rc::new(DispatcherInner { - error: Cell::new(None), - inflight: Cell::new(0), - }); - Dispatcher { - st: DispatcherState::Processing, service: service.into_service(), response: None, - state, - inner, - timer, - updated, - keepalive_timeout, + inner: DispatcherInner { + state, + timer, + updated, + keepalive_timeout, + st: DispatcherState::Processing, + shared: Rc::new(DispatcherShared { + codec, + error: Cell::new(None), + inflight: Cell::new(0), + }), + }, } } @@ -137,15 +156,15 @@ where /// By default keep-alive timeout is set to 30 seconds. pub fn keepalive_timeout(mut self, timeout: u16) -> Self { // register keepalive timer - let prev = self.updated + Duration::from_secs(self.keepalive_timeout as u64); + let prev = self.inner.updated + + Duration::from_secs(self.inner.keepalive_timeout as u64); if timeout == 0 { - self.timer.unregister(prev, &self.state); + self.inner.timer.unregister(prev, &self.inner.state); } else { - let expire = self.updated + Duration::from_secs(timeout as u64); - self.timer.register(expire, prev, &self.state); + let expire = self.inner.updated + Duration::from_secs(timeout as u64); + self.inner.timer.register(expire, prev, &self.inner.state); } - - self.keepalive_timeout = timeout; + self.inner.keepalive_timeout = timeout; self } @@ -159,14 +178,14 @@ where /// /// By default disconnect timeout is set to 1 seconds. pub fn disconnect_timeout(self, val: u16) -> Self { - self.state.set_disconnect_timeout(val); + self.inner.state.set_disconnect_timeout(val); self } } -impl DispatcherInner +impl DispatcherShared where - S: Service, Response = Option>>, + S: Service, Response = Option>>, S::Error: 'static, S::Future: 'static, U: Encoder + Decoder, @@ -175,11 +194,11 @@ where fn handle_result( &self, item: Result, - state: &State, + state: &State, wake: bool, ) { self.inflight.set(self.inflight.get() - 1); - if let Err(err) = state.write_result(item) { + if let Err(err) = state.write_result(item, &self.codec) { self.error.set(Some(err.into())); } @@ -189,9 +208,60 @@ where } } +impl DispatcherInner +where + S: Service, Response = Option>>, + U: Decoder + Encoder, +{ + fn take_error(&self) -> Option> { + // check for errors + self.shared + .error + .take() + .and_then(|err| err.convert()) + .or_else(|| self.state.take_io_error().map(DispatchItem::IoError)) + } + + /// check keepalive timeout + fn check_keepalive(&self) { + if self.state.is_keepalive() { + log::trace!("keepalive timeout"); + if let Some(err) = self.shared.error.take() { + self.shared.error.set(Some(err)); + } else { + self.shared.error.set(Some(DispatcherError::KeepAlive)); + } + self.state.dsp_mark_stopped(); + } + } + + /// update keep-alive timer + fn update_keepalive(&mut self) { + if self.keepalive_timeout != 0 { + let updated = self.timer.now(); + if updated != self.updated { + let ka = Duration::from_secs(self.keepalive_timeout as u64); + self.timer + .register(updated + ka, self.updated + ka, &self.state); + self.updated = updated; + } + } + } + + /// unregister keep-alive timer + fn unregister_keepalive(&self) { + if self.keepalive_timeout != 0 { + self.timer.unregister( + self.updated + Duration::from_secs(self.keepalive_timeout as u64), + &self.state, + ); + } + } +} + impl Future for Dispatcher where - S: Service, Response = Option>> + 'static, + S: Service, Response = Option>> + 'static, U: Decoder + Encoder + 'static, ::Item: 'static, { @@ -200,125 +270,70 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.as_mut().project(); - // log::trace!("IO-DISP poll :{:?}:", this.st); - // handle service response future if let Some(fut) = this.response.as_mut().as_pin_mut() { match fut.poll(cx) { Poll::Pending => (), Poll::Ready(item) => { - this.inner.handle_result(item, this.state, false); + this.inner + .shared + .handle_result(item, &this.inner.state, false); this.response.set(None); } } } - match this.st { + match this.inner.st { DispatcherState::Processing => { loop { - // log::trace!("IO-DISP state :{:?}:", this.state.get_flags()); - match this.service.poll_ready(cx) { Poll::Ready(Ok(_)) => { let mut retry = false; // service is ready, wake io read task - this.state.dsp_restart_read_task(); + this.inner.state.dsp_restart_read_task(); - let item = if this.state.is_dsp_stopped() { + // check keepalive timeout + this.inner.check_keepalive(); + + let item = if this.inner.state.is_dsp_stopped() { log::trace!("dispatcher is instructed to stop"); - // check keepalive timeout - if this.state.is_keepalive_err() { - if let Some(err) = this.inner.error.take() { - this.inner.error.set(Some(err)); - } else { - this.inner - .error - .set(Some(DispatcherError::KeepAlive)); - } - } else if *this.keepalive_timeout != 0 { - // unregister keep-alive timer - this.timer.unregister( - *this.updated - + Duration::from_secs( - *this.keepalive_timeout as u64, - ), - this.state, - ); - } + + // unregister keep-alive timer + this.inner.unregister_keepalive(); // check for errors - let item = this - .inner - .error - .take() - .and_then(|err| err.convert()) - .or_else(|| { - this.state - .take_io_error() - .map(DispatcherItem::IoError) - }); - *this.st = DispatcherState::Stop; retry = true; - - item + this.inner.st = DispatcherState::Stop; + this.inner.take_error() } else { // decode incoming bytes stream - - if this.state.is_read_ready() { - // this.state.with_read_buf(|buf| { - // log::trace!( - // "attempt to decode frame, buffer size is {:?}", - // buf - // ); - // }); - - match this.state.decode_item() { + if this.inner.state.is_read_ready() { + let item = this + .inner + .state + .decode_item(&this.inner.shared.codec); + match item { Ok(Some(el)) => { - // update keep-alive timer - if *this.keepalive_timeout != 0 { - let updated = this.timer.now(); - if updated != *this.updated { - let ka = Duration::from_secs( - *this.keepalive_timeout as u64, - ); - this.timer.register( - updated + ka, - *this.updated + ka, - this.state, - ); - *this.updated = updated; - } - } - - Some(DispatcherItem::Item(el)) + this.inner.update_keepalive(); + Some(DispatchItem::Item(el)) } Ok(None) => { - // log::trace!("not enough data to decode next frame, register dispatch task"); - this.state.dsp_read_more_data(cx.waker()); + log::trace!("not enough data to decode next frame, register dispatch task"); + this.inner + .state + .dsp_read_more_data(cx.waker()); return Poll::Pending; } Err(err) => { retry = true; - *this.st = DispatcherState::Stop; - - // unregister keep-alive timer - if *this.keepalive_timeout != 0 { - this.timer.unregister( - *this.updated - + Duration::from_secs( - *this.keepalive_timeout - as u64, - ), - this.state, - ); - } - - Some(DispatcherItem::DecoderError(err)) + this.inner.st = DispatcherState::Stop; + this.inner.unregister_keepalive(); + Some(DispatchItem::DecoderError(err)) } } } else { - this.state.dsp_register_task(cx.waker()); + this.inner.state.dsp_register_task(cx.waker()); return Poll::Pending; } }; @@ -336,24 +351,33 @@ where .poll(cx); if let Poll::Ready(res) = res { - if let Err(err) = this.state.write_result(res) { - this.inner.error.set(Some(err.into())); + if let Err(err) = this + .inner + .state + .write_result(res, &this.inner.shared.codec) + { + this.inner + .shared + .error + .set(Some(err.into())); } this.response.set(None); } else { this.inner + .shared .inflight - .set(this.inner.inflight.get() + 1); + .set(this.inner.shared.inflight.get() + 1); } } else { this.inner + .shared .inflight - .set(this.inner.inflight.get() + 1); - let st = this.state.clone(); - let inner = this.inner.clone(); + .set(this.inner.shared.inflight.get() + 1); + let st = this.inner.state.clone(); + let shared = this.inner.shared.clone(); crate::rt::spawn(this.service.call(item).map( move |item| { - inner.handle_result(item, &st, true); + shared.handle_result(item, &st, true); }, )); } @@ -367,27 +391,19 @@ where Poll::Pending => { // pause io read task log::trace!("service is not ready, register dispatch task"); - this.state.dsp_service_not_ready(cx.waker()); + this.inner.state.dsp_service_not_ready(cx.waker()); return Poll::Pending; } Poll::Ready(Err(err)) => { + // handle service readiness error log::trace!("service readiness check failed, stopping"); - // service readiness error - *this.st = DispatcherState::Stop; - this.state.dsp_mark_stopped(); - this.inner.error.set(Some(DispatcherError::Service(err))); - - // unregister keep-alive timer - if *this.keepalive_timeout != 0 { - this.timer.unregister( - *this.updated - + Duration::from_secs( - *this.keepalive_timeout as u64, - ), - this.state, - ); - } - + this.inner.st = DispatcherState::Stop; + this.inner.state.dsp_mark_stopped(); + this.inner + .shared + .error + .set(Some(DispatcherError::Service(err))); + this.inner.unregister_keepalive(); return self.poll(cx); } } @@ -398,18 +414,18 @@ where // service may relay on poll_ready for response results let _ = this.service.poll_ready(cx); - if this.inner.inflight.get() == 0 { - this.state.shutdown_io(); - *this.st = DispatcherState::Shutdown; + if this.inner.shared.inflight.get() == 0 { + this.inner.state.shutdown_io(); + this.inner.st = DispatcherState::Shutdown; self.poll(cx) } else { - this.state.dsp_register_task(cx.waker()); + this.inner.state.dsp_register_task(cx.waker()); Poll::Pending } } // shutdown service DispatcherState::Shutdown => { - let err = this.inner.error.take(); + let err = this.inner.shared.error.take(); if this.service.poll_shutdown(cx, err.is_some()).is_ready() { log::trace!("service shutdown is completed, stop"); @@ -420,7 +436,7 @@ where Ok(()) }) } else { - this.inner.error.set(err); + this.inner.shared.error.set(err); Poll::Pending } } @@ -441,7 +457,7 @@ mod tests { impl Dispatcher where - S: Service, Response = Option>>, + S: Service, Response = Option>>, S::Error: 'static, S::Future: 'static, U: Decoder + Encoder + 'static, @@ -452,33 +468,36 @@ mod tests { io: T, codec: U, service: F, - ) -> (Self, State) + ) -> (Self, State) where T: AsyncRead + AsyncWrite + Unpin + 'static, { - let timer = Timer::with(Duration::from_secs(1)); + let timer = Timer::default(); let keepalive_timeout = 30; let updated = timer.now(); - let state = State::new(codec); + let state = State::new(); let io = Rc::new(RefCell::new(io)); - let inner = Rc::new(DispatcherInner { + let shared = Rc::new(DispatcherShared { + codec: codec, error: Cell::new(None), inflight: Cell::new(0), }); - crate::rt::spawn(FramedReadTask::new(io.clone(), state.clone())); - crate::rt::spawn(FramedWriteTask::new(io.clone(), state.clone())); + crate::rt::spawn(ReadTask::new(io.clone(), state.clone())); + crate::rt::spawn(WriteTask::new(io.clone(), state.clone())); ( Dispatcher { service: service.into_service(), - state: state.clone(), - st: DispatcherState::Processing, response: None, - timer, - updated, - keepalive_timeout, - inner, + inner: DispatcherInner { + shared, + timer, + updated, + keepalive_timeout, + state: state.clone(), + st: DispatcherState::Processing, + }, }, state, ) @@ -494,9 +513,9 @@ mod tests { let (disp, _) = Dispatcher::debug( server, BytesCodec, - crate::fn_service(|msg: DispatcherItem| async move { + crate::fn_service(|msg: DispatchItem| async move { delay_for(Duration::from_millis(50)).await; - if let DispatcherItem::Item(msg) = msg { + if let DispatchItem::Item(msg) = msg { Ok::<_, ()>(Some(msg.freeze())) } else { panic!() @@ -521,8 +540,8 @@ mod tests { let (disp, st) = Dispatcher::debug( server, BytesCodec, - crate::fn_service(|msg: DispatcherItem| async move { - if let DispatcherItem::Item(msg) = msg { + crate::fn_service(|msg: DispatchItem| async move { + if let DispatchItem::Item(msg) = msg { Ok::<_, ()>(Some(msg.freeze())) } else { panic!() @@ -534,7 +553,9 @@ mod tests { let buf = client.read().await.unwrap(); assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); - assert!(st.write_item(Bytes::from_static(b"test")).is_ok()); + assert!(st + .write_item(Bytes::from_static(b"test"), &mut BytesCodec) + .is_ok()); let buf = client.read().await.unwrap(); assert_eq!(buf, Bytes::from_static(b"test")); @@ -552,14 +573,17 @@ mod tests { let (disp, state) = Dispatcher::debug( server, BytesCodec, - crate::fn_service(|_: DispatcherItem| async move { + crate::fn_service(|_: DispatchItem| async move { Err::, _>(()) }), ); crate::rt::spawn(disp.map(|_| ())); state - .write_item(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")) + .write_item( + Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), + &mut BytesCodec, + ) .unwrap(); let buf = client.read_any(); diff --git a/ntex/src/framed/mod.rs b/ntex/src/framed/mod.rs index 6761f8f6..b459d21a 100644 --- a/ntex/src/framed/mod.rs +++ b/ntex/src/framed/mod.rs @@ -1,3 +1,5 @@ +use std::{fmt, io}; + mod dispatcher; mod read; mod state; @@ -5,7 +7,48 @@ mod time; mod write; pub use self::dispatcher::Dispatcher; -pub use self::read::FramedReadTask; -pub use self::state::{DispatcherItem, State}; +pub use self::read::ReadTask; +pub use self::state::State; pub use self::time::Timer; -pub use self::write::FramedWriteTask; +pub use self::write::WriteTask; + +use crate::codec::{Decoder, Encoder}; + +/// Framed transport item +pub enum DispatchItem { + Item(::Item), + /// Keep alive timeout + KeepAliveTimeout, + /// Decoder parse error + DecoderError(::Error), + /// Encoder parse error + EncoderError(::Error), + /// Unexpected io error + IoError(io::Error), +} + +impl fmt::Debug for DispatchItem +where + U: Encoder + Decoder, + ::Item: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DispatchItem::Item(ref item) => { + write!(fmt, "DispatchItem::Item({:?})", item) + } + DispatchItem::KeepAliveTimeout => { + write!(fmt, "DispatchItem::KeepAliveTimeout") + } + DispatchItem::EncoderError(ref e) => { + write!(fmt, "DispatchItem::EncoderError({:?})", e) + } + DispatchItem::DecoderError(ref e) => { + write!(fmt, "DispatchItem::DecoderError({:?})", e) + } + DispatchItem::IoError(ref e) => { + write!(fmt, "DispatchItem::IoError({:?})", e) + } + } + } +} diff --git a/ntex/src/framed/read.rs b/ntex/src/framed/read.rs index 79344109..f2736b25 100644 --- a/ntex/src/framed/read.rs +++ b/ntex/src/framed/read.rs @@ -11,25 +11,25 @@ const LW: usize = 1024; const HW: usize = 8 * 1024; /// Read io task -pub struct FramedReadTask +pub struct ReadTask where T: AsyncRead + AsyncWrite + Unpin, { io: Rc>, - state: State, + state: State, } -impl FramedReadTask +impl ReadTask where T: AsyncRead + AsyncWrite + Unpin, { /// Create new read io task - pub fn new(io: Rc>, state: State) -> Self { + pub fn new(io: Rc>, state: State) -> Self { Self { io, state } } } -impl Future for FramedReadTask +impl Future for ReadTask where T: AsyncRead + AsyncWrite + Unpin, { diff --git a/ntex/src/framed/state.rs b/ntex/src/framed/state.rs index 42c26296..e652d75c 100644 --- a/ntex/src/framed/state.rs +++ b/ntex/src/framed/state.rs @@ -1,6 +1,6 @@ //! Framed transport dispatcher use std::task::{Context, Poll, Waker}; -use std::{cell::Cell, cell::RefCell, fmt, hash, io, mem, pin::Pin, rc::Rc}; +use std::{cell::Cell, cell::RefCell, hash, io, mem, pin::Pin, rc::Rc}; use bytes::BytesMut; use either::Either; @@ -10,13 +10,10 @@ use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts} use crate::framed::write::flush; use crate::task::LocalWaker; -type Request = ::Item; -type Response = ::Item; - const HW: usize = 8 * 1024; bitflags::bitflags! { - pub(crate) struct Flags: u8 { + pub struct Flags: u8 { const DSP_STOP = 0b0000_0001; const DSP_KEEPALIVE = 0b0000_0010; @@ -35,49 +32,9 @@ bitflags::bitflags! { } } -/// Framed transport item -pub enum DispatcherItem { - Item(Request), - /// Keep alive timeout - KeepAliveTimeout, - /// Decoder parse error - DecoderError(::Error), - /// Encoder parse error - EncoderError(::Error), - /// Unexpected io error - IoError(io::Error), -} +pub struct State(Rc); -impl fmt::Debug for DispatcherItem -where - U: Encoder + Decoder, - ::Item: fmt::Debug, -{ - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - DispatcherItem::Item(ref item) => { - write!(fmt, "DispatcherItem::Item({:?})", item) - } - DispatcherItem::KeepAliveTimeout => { - write!(fmt, "DispatcherItem::KeepAliveTimeout") - } - DispatcherItem::EncoderError(ref e) => { - write!(fmt, "DispatcherItem::EncoderError({:?})", e) - } - DispatcherItem::DecoderError(ref e) => { - write!(fmt, "DispatcherItem::DecoderError({:?})", e) - } - DispatcherItem::IoError(ref e) => { - write!(fmt, "DispatcherItem::IoError({:?})", e) - } - } - } -} - -pub struct State(Rc>); - -pub(crate) struct IoStateInner { - codec: RefCell, +pub(crate) struct IoStateInner { flags: Cell, error: Cell>, disconnect_timeout: Cell, @@ -88,39 +45,29 @@ pub(crate) struct IoStateInner { write_buf: RefCell, } -impl State { - pub(crate) fn keepalive_timeout(&self) { - let state = self.0.as_ref(); - let mut flags = state.flags.get(); - flags.insert(Flags::DSP_STOP | Flags::DSP_KEEPALIVE); - state.flags.set(flags); - state.dispatch_task.wake(); - } -} - -impl Clone for State { +impl Clone for State { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl Eq for State {} +impl Eq for State {} -impl PartialEq for State { +impl PartialEq for State { fn eq(&self, other: &Self) -> bool { Rc::as_ptr(&self.0) == Rc::as_ptr(&other.0) } } -impl hash::Hash for State { +impl hash::Hash for State { fn hash(&self, state: &mut H) { Rc::as_ptr(&self.0).hash(state); } } -impl State { +impl State { /// Create `State` instance - pub fn new(codec: U) -> Self { + pub fn new() -> Self { State(Rc::new(IoStateInner { flags: Cell::new(Flags::empty()), error: Cell::new(None), @@ -128,14 +75,13 @@ impl State { dispatch_task: LocalWaker::new(), read_task: LocalWaker::new(), write_task: LocalWaker::new(), - codec: RefCell::new(codec), read_buf: RefCell::new(BytesMut::new()), write_buf: RefCell::new(BytesMut::new()), })) } /// Create `State` from Framed - pub fn from_framed(framed: Framed) -> (Io, Self) { + pub fn from_framed(framed: Framed) -> (Io, U, Self) { let parts = framed.into_parts(); let state = State(Rc::new(IoStateInner { @@ -145,30 +91,38 @@ impl State { dispatch_task: LocalWaker::new(), read_task: LocalWaker::new(), write_task: LocalWaker::new(), - codec: RefCell::new(parts.codec), read_buf: RefCell::new(parts.read_buf), write_buf: RefCell::new(parts.write_buf), })); - (parts.io, state) + (parts.io, parts.codec, state) } /// Convert state to a Framed instance - pub fn into_framed(self, io: Io) -> Result, Io> { - match Rc::try_unwrap(self.0) { - Ok(inner) => { - let mut parts = FramedParts::new(io, inner.codec.into_inner()); - parts.read_buf = inner.read_buf.into_inner(); - parts.write_buf = inner.write_buf.into_inner(); - Ok(Framed::from_parts(parts)) - } - Err(_) => Err(io), - } + pub fn into_framed(self, io: Io, codec: U) -> Framed { + let mut parts = FramedParts::new(io, codec); + parts.read_buf = mem::take(&mut self.0.read_buf.borrow_mut()); + parts.write_buf = mem::take(&mut self.0.write_buf.borrow_mut()); + Framed::from_parts(parts) + } + + pub(crate) fn keepalive_timeout(&self) { + let state = self.0.as_ref(); + let mut flags = state.flags.get(); + flags.insert(Flags::DSP_KEEPALIVE); + state.flags.set(flags); + state.dispatch_task.wake(); } pub(super) fn disconnect_timeout(&self) -> u16 { self.0.disconnect_timeout.get() } + #[inline] + /// Get current state flags + pub fn flags(&self) -> Flags { + self.0.flags.get() + } + #[inline] /// Set disconnecto timeout pub fn set_disconnect_timeout(&self, timeout: u16) { @@ -212,10 +166,18 @@ impl State { #[inline] /// Check if keep-alive timeout occured - pub fn is_keepalive_err(&self) -> bool { + pub fn is_keepalive(&self) -> bool { self.0.flags.get().contains(Flags::DSP_KEEPALIVE) } + #[inline] + /// Reset keep-alive error + pub fn reset_keepalive(&self) { + let mut flags = self.0.flags.get(); + flags.remove(Flags::DSP_KEEPALIVE); + self.0.flags.set(flags); + } + #[inline] /// Check is dispatcher marked stopped pub fn is_dsp_stopped(&self) -> bool { @@ -377,6 +339,7 @@ impl State { } #[inline] + /// Get mut access to read buffer pub fn with_read_buf(&self, f: F) -> R where F: FnOnce(&mut BytesMut) -> R, @@ -385,6 +348,7 @@ impl State { } #[inline] + /// Get mut access to write buffer pub fn with_write_buf(&self, f: F) -> R where F: FnOnce(&mut BytesMut) -> R, @@ -393,56 +357,31 @@ impl State { } } -impl State -where - U: Encoder + Decoder, -{ +impl State { #[inline] - /// Consume the `IoState`, returning `IoState` with different codec. - pub fn map_codec(self, f: F) -> State + /// Attempts to decode a frame from the read buffer. + pub fn decode_item( + &self, + codec: &U, + ) -> Result::Item>, ::Error> where - F: Fn(&U) -> U2, - U2: Encoder + Decoder, + U: Decoder, { - let st = self.0.as_ref(); - let codec = f(&st.codec.borrow()); - - State(Rc::new(IoStateInner { - codec: RefCell::new(codec), - flags: Cell::new(st.flags.get()), - error: Cell::new(st.error.take()), - disconnect_timeout: Cell::new(st.disconnect_timeout.get()), - dispatch_task: LocalWaker::new(), - read_task: LocalWaker::new(), - write_task: LocalWaker::new(), - read_buf: RefCell::new(mem::take(&mut st.read_buf.borrow_mut())), - write_buf: RefCell::new(mem::take(&mut st.write_buf.borrow_mut())), - })) + codec.decode(&mut self.0.read_buf.borrow_mut()) } #[inline] - pub fn with_codec(&self, f: F) -> R - where - F: FnOnce(&mut U) -> R, - { - f(&mut *self.0.codec.borrow_mut()) - } - - #[inline] - pub async fn next( + pub async fn next( &self, io: &mut T, - ) -> Result::Item>, Either<::Error, io::Error>> + codec: &mut U, + ) -> Result, Either> where T: AsyncRead + AsyncWrite + Unpin, + U: Decoder, { loop { - let item = { - self.0 - .codec - .borrow_mut() - .decode(&mut self.0.read_buf.borrow_mut()) - }; + let item = codec.decode(&mut self.0.read_buf.borrow_mut()); return match item { Ok(Some(el)) => Ok(Some(el)), Ok(None) => { @@ -468,18 +407,17 @@ where } #[inline] - pub fn poll_next( + pub fn poll_next( &self, io: &mut T, + codec: &mut U, cx: &mut Context<'_>, - ) -> Poll< - Result::Item>, Either<::Error, io::Error>>, - > + ) -> Poll, Either>> where T: AsyncRead + AsyncWrite + Unpin, + U: Decoder, { let mut buf = self.0.read_buf.borrow_mut(); - let mut codec = self.0.codec.borrow_mut(); loop { return match codec.decode(&mut buf) { @@ -502,17 +440,18 @@ where } #[inline] - pub async fn send( + /// Encode item, send to a peer and flush + pub async fn send( &self, io: &mut T, - item: ::Item, - ) -> Result<(), Either<::Error, io::Error>> + codec: &U, + item: U::Item, + ) -> Result<(), Either> where T: AsyncRead + AsyncWrite + Unpin, + U: Encoder, { - self.0 - .codec - .borrow_mut() + codec .encode(item, &mut self.0.write_buf.borrow_mut()) .map_err(Either::Left)?; @@ -525,25 +464,18 @@ where }) } - #[inline] - /// Attempts to decode a frame from the read buffer. - pub fn decode_item( - &self, - ) -> Result::Item>, ::Error> { - self.0 - .codec - .borrow_mut() - .decode(&mut self.0.read_buf.borrow_mut()) - } - #[inline] /// Write item to a buf and wake up io task /// /// Returns state of write buffer state, false is returned if write buffer if full. - pub fn write_item( + pub fn write_item( &self, - item: ::Item, - ) -> Result::Error> { + item: U::Item, + codec: &U, + ) -> Result::Error> + where + U: Encoder, + { let flags = self.0.flags.get(); if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { @@ -551,10 +483,7 @@ where let is_write_sleep = write_buf.is_empty(); // encode item and wake write task - let res = self - .0 - .codec - .borrow_mut() + let res = codec .encode(item, &mut *write_buf) .map(|_| write_buf.len() < HW); if res.is_ok() && is_write_sleep { @@ -568,10 +497,14 @@ where #[inline] /// Write item to a buf and wake up io task - pub fn write_result( + pub fn write_result( &self, - item: Result>, E>, - ) -> Result::Error>> { + item: Result, E>, + codec: &U, + ) -> Result> + where + U: Encoder, + { let mut flags = self.0.flags.get(); if !flags.intersects(Flags::IO_ERR | Flags::ST_DSP_ERR) { @@ -581,9 +514,7 @@ where let is_write_sleep = write_buf.is_empty(); // encode item - if let Err(err) = - self.0.codec.borrow_mut().encode(item, &mut write_buf) - { + if let Err(err) = codec.encode(item, &mut write_buf) { log::trace!("Codec encoder error: {:?}", err); flags.insert(Flags::DSP_STOP | Flags::ST_DSP_ERR); self.0.flags.set(flags); diff --git a/ntex/src/framed/time.rs b/ntex/src/framed/time.rs index c4865f1b..41ef0886 100644 --- a/ntex/src/framed/time.rs +++ b/ntex/src/framed/time.rs @@ -6,15 +6,15 @@ use crate::framed::State; use crate::rt::time::delay_for; use crate::HashSet; -pub struct Timer(Rc>>); +pub struct Timer(Rc>); -struct Inner { +struct Inner { resolution: Duration, current: Option, - notifications: BTreeMap>>, + notifications: BTreeMap>, } -impl Inner { +impl Inner { fn new(resolution: Duration) -> Self { Inner { resolution, @@ -23,7 +23,7 @@ impl Inner { } } - fn unregister(&mut self, expire: Instant, state: &State) { + fn unregister(&mut self, expire: Instant, state: &State) { if let Some(ref mut states) = self.notifications.get_mut(&expire) { states.remove(state); if states.is_empty() { @@ -33,18 +33,24 @@ impl Inner { } } -impl Clone for Timer { +impl Clone for Timer { fn clone(&self) -> Self { Timer(self.0.clone()) } } -impl Timer { - pub fn with(resolution: Duration) -> Timer { +impl Default for Timer { + fn default() -> Self { + Timer::with(Duration::from_secs(1)) + } +} + +impl Timer { + pub fn with(resolution: Duration) -> Timer { Timer(Rc::new(RefCell::new(Inner::new(resolution)))) } - pub fn register(&self, expire: Instant, previous: Instant, state: &State) { + pub fn register(&self, expire: Instant, previous: Instant, state: &State) { { let mut inner = self.0.borrow_mut(); @@ -59,7 +65,7 @@ impl Timer { let _ = self.now(); } - pub fn unregister(&self, expire: Instant, state: &State) { + pub fn unregister(&self, expire: Instant, state: &State) { self.0.borrow_mut().unregister(expire, state); } diff --git a/ntex/src/framed/write.rs b/ntex/src/framed/write.rs index a76a1be4..a8edae20 100644 --- a/ntex/src/framed/write.rs +++ b/ntex/src/framed/write.rs @@ -3,7 +3,7 @@ use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc, time::Duration}; use bytes::{Buf, BytesMut}; -use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder}; +use crate::codec::{AsyncRead, AsyncWrite}; use crate::framed::State; use crate::rt::time::{delay_for, Delay}; @@ -23,25 +23,21 @@ enum Shutdown { } /// Write io task -pub struct FramedWriteTask +pub struct WriteTask where T: AsyncRead + AsyncWrite + Unpin, - U: Encoder + Decoder, - ::Item: 'static, { st: IoWriteState, io: Rc>, - state: State, + state: State, } -impl FramedWriteTask +impl WriteTask where T: AsyncRead + AsyncWrite + Unpin, - U: Encoder + Decoder, - ::Item: 'static, { /// Create new write io task - pub fn new(io: Rc>, state: State) -> Self { + pub fn new(io: Rc>, state: State) -> Self { Self { io, state, @@ -50,7 +46,7 @@ where } /// Shutdown io stream - pub fn shutdown(io: Rc>, state: State) -> Self { + pub fn shutdown(io: Rc>, state: State) -> Self { let disconnect_timeout = state.disconnect_timeout() as u64; let st = IoWriteState::Shutdown( if disconnect_timeout != 0 { @@ -65,11 +61,9 @@ where } } -impl Future for FramedWriteTask +impl Future for WriteTask where T: AsyncRead + AsyncWrite + Unpin, - U: Encoder + Decoder, - ::Item: 'static, { type Output = (); @@ -204,8 +198,8 @@ pub(super) fn flush( where T: AsyncRead + AsyncWrite + Unpin, { - // log::trace!("flushing framed transport: {}", len); let len = buf.len(); + log::trace!("flushing framed transport: {}", len); if len != 0 { let mut written = 0; diff --git a/ntex/src/http/builder.rs b/ntex/src/http/builder.rs index 20c66e21..05c321be 100644 --- a/ntex/src/http/builder.rs +++ b/ntex/src/http/builder.rs @@ -1,8 +1,6 @@ -use std::fmt; -use std::marker::PhantomData; -use std::rc::Rc; +use std::{error::Error, fmt, marker::PhantomData, rc::Rc}; -use crate::codec::Framed; +use crate::framed::State; use crate::http::body::MessageBody; use crate::http::config::{KeepAlive, ServiceConfig}; use crate::http::error::ResponseError; @@ -34,9 +32,9 @@ impl HttpServiceBuilder> { pub fn new() -> Self { HttpServiceBuilder { keep_alive: KeepAlive::Timeout(5), - client_timeout: 3000, - client_disconnect: 3000, - handshake_timeout: 5000, + client_timeout: 3, + client_disconnect: 3, + handshake_timeout: 5, expect: ExpectHandler, upgrade: None, on_connect: None, @@ -53,12 +51,12 @@ where S::Future: 'static, ::Future: 'static, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory), Response = ()>, - U::Error: fmt::Display, + U: ServiceFactory, + U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, ::Future: 'static, @@ -71,7 +69,7 @@ where self } - /// Set server client timeout in milliseconds for first request. + /// Set server client timeout in seconds for first request. /// /// Defines a timeout for reading client request header. If a client does not transmit /// the entire set headers within this time, the request is terminated with @@ -80,8 +78,8 @@ where /// To disable timeout set value to 0. /// /// By default client timeout is set to 3 seconds. - pub fn client_timeout(mut self, val: u64) -> Self { - self.client_timeout = val; + pub fn client_timeout(mut self, val: u16) -> Self { + self.client_timeout = val as u64; self } @@ -93,19 +91,19 @@ where /// To disable timeout set value to 0. /// /// By default disconnect timeout is set to 3 seconds. - pub fn disconnect_timeout(mut self, val: u64) -> Self { - self.client_disconnect = val; + pub fn disconnect_timeout(mut self, val: u16) -> Self { + self.client_disconnect = val as u64; self } - /// Set server ssl handshake timeout in milliseconds. + /// Set server ssl handshake timeout in seconds. /// /// Defines a timeout for connection ssl handshake negotiation. /// To disable timeout set value to 0. /// /// By default handshake timeout is set to 5 seconds. - pub fn ssl_handshake_timeout(mut self, val: u64) -> Self { - self.handshake_timeout = val; + pub fn ssl_handshake_timeout(mut self, val: u16) -> Self { + self.handshake_timeout = val as u64; self } @@ -118,7 +116,7 @@ where where F: IntoServiceFactory, X1: ServiceFactory, - X1::Error: ResponseError, + X1::Error: ResponseError + 'static, X1::InitError: fmt::Debug, X1::Future: 'static, ::Future: 'static, @@ -144,10 +142,10 @@ where F: IntoServiceFactory, U1: ServiceFactory< Config = (), - Request = (Request, Framed), + Request = (Request, State, Codec), Response = (), >, - U1::Error: fmt::Display, + U1::Error: fmt::Display + Error + 'static, U1::InitError: fmt::Debug, U1::Future: 'static, ::Future: 'static, diff --git a/ntex/src/http/client/request.rs b/ntex/src/http/client/request.rs index 09c657ae..857037fd 100644 --- a/ntex/src/http/client/request.rs +++ b/ntex/src/http/client/request.rs @@ -376,10 +376,7 @@ impl ClientRequest { pub fn freeze(self) -> Result { let slf = match self.prep_for_sending() { Ok(slf) => slf, - Err(e) => { - println!("E: {:?}", e); - return Err(e.into()); - } + Err(e) => return Err(e.into()), }; let request = FrozenClientRequest { diff --git a/ntex/src/http/client/response.rs b/ntex/src/http/client/response.rs index f80b9e1e..d14dddad 100644 --- a/ntex/src/http/client/response.rs +++ b/ntex/src/http/client/response.rs @@ -1,8 +1,6 @@ use std::cell::{Ref, RefMut}; -use std::fmt; -use std::marker::PhantomData; -use std::pin::Pin; use std::task::{Context, Poll}; +use std::{fmt, marker::PhantomData, mem, pin::Pin}; use bytes::{Bytes, BytesMut}; use futures::{ready, Future, Stream}; @@ -13,18 +11,19 @@ use coo_kie::{Cookie, ParseError as CookieParseError}; use crate::http::error::PayloadError; use crate::http::header::CONTENT_LENGTH; -use crate::http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead}; use crate::http::{HeaderMap, StatusCode, Version}; +use crate::http::{HttpMessage, Payload, ResponseHead}; +use crate::util::Extensions; use super::error::JsonPayloadError; /// Client Response -pub struct ClientResponse { +pub struct ClientResponse { pub(crate) head: ResponseHead, - pub(crate) payload: Payload, + pub(crate) payload: Payload, } -impl HttpMessage for ClientResponse { +impl HttpMessage for ClientResponse { fn message_headers(&self) -> &HeaderMap { &self.head.headers } @@ -59,9 +58,9 @@ impl HttpMessage for ClientResponse { } } -impl ClientResponse { +impl ClientResponse { /// Create new Request instance - pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { + pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { ClientResponse { head, payload } } @@ -89,21 +88,13 @@ impl ClientResponse { } /// Set a body and return previous body value - pub fn map_body(mut self, f: F) -> ClientResponse - where - F: FnOnce(&mut ResponseHead, Payload) -> Payload, - { - let payload = f(&mut self.head, self.payload); - - ClientResponse { - payload, - head: self.head, - } + pub fn set_payload(&mut self, payload: Payload) { + self.payload = payload; } /// Get response's payload - pub fn take_payload(&mut self) -> Payload { - std::mem::take(&mut self.payload) + pub fn take_payload(&mut self) -> Payload { + mem::take(&mut self.payload) } /// Request extensions @@ -119,12 +110,9 @@ impl ClientResponse { } } -impl ClientResponse -where - S: Stream>, -{ +impl ClientResponse { /// Loads http response's body. - pub fn body(&mut self) -> MessageBody { + pub fn body(&mut self) -> MessageBody { MessageBody::new(self) } @@ -135,15 +123,12 @@ where /// /// * content type is not `application/json` /// * content length is greater than 256k - pub fn json(&mut self) -> JsonBody { + pub fn json(&mut self) -> JsonBody { JsonBody::new(self) } } -impl Stream for ClientResponse -where - S: Stream> + Unpin, -{ +impl Stream for ClientResponse { type Item = Result; fn poll_next( @@ -154,7 +139,7 @@ where } } -impl fmt::Debug for ClientResponse { +impl fmt::Debug for ClientResponse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?; writeln!(f, " headers:")?; @@ -166,18 +151,15 @@ impl fmt::Debug for ClientResponse { } /// Future that resolves to a complete http message body. -pub struct MessageBody { +pub struct MessageBody { length: Option, err: Option, - fut: Option>, + fut: Option, } -impl MessageBody -where - S: Stream>, -{ +impl MessageBody { /// Create `MessageBody` for request. - pub fn new(res: &mut ClientResponse) -> MessageBody { + pub fn new(res: &mut ClientResponse) -> MessageBody { let mut len = None; if let Some(l) = res.headers().get(&CONTENT_LENGTH) { if let Ok(s) = l.to_str() { @@ -215,10 +197,7 @@ where } } -impl Future for MessageBody -where - S: Stream> + Unpin, -{ +impl Future for MessageBody { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -244,20 +223,19 @@ where /// /// * content type is not `application/json` /// * content length is greater than 64k -pub struct JsonBody { +pub struct JsonBody { length: Option, err: Option, - fut: Option>, + fut: Option, _t: PhantomData, } -impl JsonBody +impl JsonBody where - S: Stream>, U: DeserializeOwned, { /// Create `JsonBody` for request. - pub fn new(req: &mut ClientResponse) -> Self { + pub fn new(req: &mut ClientResponse) -> Self { // check content-type let json = if let Ok(Some(mime)) = req.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) @@ -299,16 +277,10 @@ where } } -impl Unpin for JsonBody -where - T: Stream> + Unpin, - U: DeserializeOwned, -{ -} +impl Unpin for JsonBody where U: DeserializeOwned {} -impl Future for JsonBody +impl Future for JsonBody where - T: Stream> + Unpin, U: DeserializeOwned, { type Output = Result; @@ -331,14 +303,14 @@ where } } -struct ReadBody { - stream: Payload, +struct ReadBody { + stream: Payload, buf: BytesMut, limit: usize, } -impl ReadBody { - fn new(stream: Payload, limit: usize) -> Self { +impl ReadBody { + fn new(stream: Payload, limit: usize) -> Self { Self { stream, buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)), @@ -347,10 +319,7 @@ impl ReadBody { } } -impl Future for ReadBody -where - S: Stream> + Unpin, -{ +impl Future for ReadBody { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -432,7 +401,7 @@ mod tests { #[ntex_rt::test] async fn test_json_body() { let mut req = TestResponse::default().finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).await; + let json = JsonBody::::new(&mut req).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let mut req = TestResponse::default() @@ -441,7 +410,7 @@ mod tests { header::HeaderValue::from_static("application/text"), ) .finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).await; + let json = JsonBody::::new(&mut req).await; assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType)); let mut req = TestResponse::default() @@ -455,7 +424,7 @@ mod tests { ) .finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await; + let json = JsonBody::::new(&mut req).limit(100).await; assert!(json_eq( json.err().unwrap(), JsonPayloadError::Payload(PayloadError::Overflow) @@ -473,7 +442,7 @@ mod tests { .set_payload(Bytes::from_static(b"{\"name\": \"test\"}")) .finish(); - let json = JsonBody::<_, MyObject>::new(&mut req).await; + let json = JsonBody::::new(&mut req).await; assert_eq!( json.ok().unwrap(), MyObject { diff --git a/ntex/src/http/client/sender.rs b/ntex/src/http/client/sender.rs index 3236897e..fcc638f1 100644 --- a/ntex/src/http/client/sender.rs +++ b/ntex/src/http/client/sender.rs @@ -19,9 +19,7 @@ use crate::rt::time::{delay_for, Delay}; #[cfg(feature = "compress")] use crate::http::encoding::Decoder; #[cfg(feature = "compress")] -use crate::http::header::ContentEncoding; -#[cfg(feature = "compress")] -use crate::http::{Payload, PayloadStream}; +use crate::http::Payload; use super::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use super::response::ClientResponse; @@ -74,10 +72,6 @@ impl SendClientRequest { } impl Future for SendClientRequest { - #[cfg(feature = "compress")] - type Output = - Result>>, SendRequestError>; - #[cfg(not(feature = "compress"))] type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -95,20 +89,15 @@ impl Future for SendClientRequest { let res = futures::ready!(Pin::new(send).poll(cx)); #[cfg(feature = "compress")] - let res = res.map(|res| { - res.map_body(|head, payload| { - if *_response_decompress { - Payload::Stream(Decoder::from_headers( - payload, - &head.headers, - )) - } else { - Payload::Stream(Decoder::new( - payload, - ContentEncoding::Identity, - )) - } - }) + let res = res.map(|mut res| { + if *_response_decompress { + let payload = res.take_payload(); + res.set_payload(Payload::from_stream(Decoder::from_headers( + payload, + &res.head.headers, + ))) + } + res }); Poll::Ready(res) diff --git a/ntex/src/http/config.rs b/ntex/src/http/config.rs index ff3f5a18..4d35676e 100644 --- a/ntex/src/http/config.rs +++ b/ntex/src/http/config.rs @@ -1,14 +1,12 @@ -use std::cell::UnsafeCell; -use std::fmt; -use std::fmt::Write; -use std::ptr::copy_nonoverlapping; -use std::rc::Rc; -use std::time::Duration; +use std::{ + cell::UnsafeCell, fmt, fmt::Write, ptr::copy_nonoverlapping, rc::Rc, time::Duration, +}; use bytes::BytesMut; use futures::{future, FutureExt}; use time::OffsetDateTime; +use crate::framed::Timer; use crate::rt::time::{delay_for, delay_until, Delay, Instant}; // "Sun, 06 Nov 1994 08:49:37 GMT".len() @@ -45,12 +43,13 @@ impl From> for KeepAlive { pub struct ServiceConfig(pub(super) Rc); pub(super) struct Inner { - pub(super) keep_alive: Option, + pub(super) keep_alive: u64, pub(super) client_timeout: u64, pub(super) client_disconnect: u64, pub(super) ka_enabled: bool, pub(super) timer: DateService, pub(super) ssl_handshake_timeout: u64, + pub(super) timer_h1: Timer, } impl Clone for ServiceConfig { @@ -79,9 +78,9 @@ impl ServiceConfig { KeepAlive::Disabled => (0, false), }; let keep_alive = if ka_enabled && keep_alive > 0 { - Some(Duration::from_secs(keep_alive)) + keep_alive } else { - None + 0 }; ServiceConfig(Rc::new(Inner { @@ -91,6 +90,7 @@ impl ServiceConfig { client_disconnect, ssl_handshake_timeout, timer: DateService::new(), + timer_h1: Timer::default(), })) } } @@ -99,11 +99,12 @@ pub(super) struct DispatcherConfig { pub(super) service: S, pub(super) expect: X, pub(super) upgrade: Option, - pub(super) keep_alive: Option, + pub(super) keep_alive: u64, pub(super) client_timeout: u64, pub(super) client_disconnect: u64, pub(super) ka_enabled: bool, pub(super) timer: DateService, + pub(super) timer_h1: Timer, } impl DispatcherConfig { @@ -122,6 +123,7 @@ impl DispatcherConfig { client_disconnect: cfg.0.client_disconnect, ka_enabled: cfg.0.ka_enabled, timer: cfg.0.timer.clone(), + timer_h1: cfg.0.timer_h1.clone(), } } @@ -130,37 +132,12 @@ impl DispatcherConfig { self.ka_enabled } - /// Client timeout for first request. - pub(super) fn client_timer(&self) -> Option { - let delay_time = self.client_timeout; - if delay_time != 0 { - Some(delay_until( - self.timer.now() + Duration::from_millis(delay_time), - )) - } else { - None - } - } - - /// Client disconnect timer - pub(super) fn client_disconnect_timer(&self) -> Option { - let delay = self.client_disconnect; - if delay != 0 { - Some(self.timer.now() + Duration::from_millis(delay)) - } else { - None - } - } - - /// Return state of connection keep-alive timer - pub(super) fn keep_alive_timer_enabled(&self) -> bool { - self.keep_alive.is_some() - } - /// Return keep-alive timer delay is configured. pub(super) fn keep_alive_timer(&self) -> Option { - if let Some(ka) = self.keep_alive { - Some(delay_until(self.timer.now() + ka)) + if self.keep_alive != 0 { + Some(delay_until( + self.timer.now() + Duration::from_secs(self.keep_alive), + )) } else { None } @@ -168,8 +145,8 @@ impl DispatcherConfig { /// Keep-alive expire time pub(super) fn keep_alive_expire(&self) -> Option { - if let Some(ka) = self.keep_alive { - Some(self.timer.now() + ka) + if self.keep_alive != 0 { + Some(self.timer.now() + Duration::from_secs(self.keep_alive)) } else { None } diff --git a/ntex/src/http/error.rs b/ntex/src/http/error.rs index c232a5a5..a4feb30b 100644 --- a/ntex/src/http/error.rs +++ b/ntex/src/http/error.rs @@ -16,7 +16,7 @@ use super::body::Body; use super::response::Response; /// Error that can be converted to `Response` -pub trait ResponseError: fmt::Display + fmt::Debug + 'static { +pub trait ResponseError: fmt::Display + fmt::Debug { /// Create response for error /// /// Internal server error is generated by default. @@ -32,6 +32,12 @@ pub trait ResponseError: fmt::Display + fmt::Debug + 'static { } } +impl<'a, T: ResponseError> ResponseError for &'a T { + fn error_response(&self) -> Response { + (*self).error_response() + } +} + impl From for Response { fn from(err: T) -> Response { let resp = err.error_response(); @@ -180,8 +186,9 @@ pub enum DispatchError { /// Service error Service(Box), + #[from(ignore)] /// Upgrade service error - Upgrade, + Upgrade(Box), /// An `io::Error` that occurred while trying to read or write to a network /// stream. @@ -192,6 +199,11 @@ pub enum DispatchError { #[display(fmt = "Parse error: {}", _0)] Parse(ParseError), + /// Http response encoding error. + #[display(fmt = "Encode error: {}", _0)] + #[from(ignore)] + Encode(io::Error), + /// Http/2 error #[display(fmt = "{}", _0)] H2(h2::Error), @@ -212,6 +224,10 @@ pub enum DispatchError { #[display(fmt = "Malformed request")] MalformedRequest, + /// Response body processing error + #[display(fmt = "Response body processing error: {}", _0)] + ResponsePayload(Box), + /// Internal error #[display(fmt = "Internal error")] InternalError, diff --git a/ntex/src/http/h1/client.rs b/ntex/src/http/h1/client.rs index 879c08ad..eb019233 100644 --- a/ntex/src/http/h1/client.rs +++ b/ntex/src/http/h1/client.rs @@ -1,4 +1,4 @@ -use std::io; +use std::{cell::Cell, cell::RefCell, io}; use bitflags::bitflags; use bytes::{Bytes, BytesMut}; @@ -34,12 +34,12 @@ pub struct ClientPayloadCodec { struct ClientCodecInner { timer: DateService, decoder: decoder::MessageDecoder, - payload: Option, - version: Version, - ctype: ConnectionType, + payload: RefCell>, + version: Cell, + ctype: Cell, // encoder part - flags: Flags, + flags: Cell, encoder: encoder::MessageEncoder, } @@ -63,11 +63,10 @@ impl ClientCodec { inner: ClientCodecInner { timer, decoder: decoder::MessageDecoder::default(), - payload: None, - version: Version::HTTP_11, - ctype: ConnectionType::Close, - - flags, + payload: RefCell::new(None), + version: Cell::new(Version::HTTP_11), + ctype: Cell::new(ConnectionType::Close), + flags: Cell::new(flags), encoder: encoder::MessageEncoder::default(), }, } @@ -75,19 +74,19 @@ impl ClientCodec { /// Check if request is upgrade pub fn upgrade(&self) -> bool { - self.inner.ctype == ConnectionType::Upgrade + self.inner.ctype.get() == ConnectionType::Upgrade } /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { - self.inner.ctype == ConnectionType::KeepAlive + self.inner.ctype.get() == ConnectionType::KeepAlive } /// Check last request's message type pub fn message_type(&self) -> MessageType { - if self.inner.flags.contains(Flags::STREAM) { + if self.inner.flags.get().contains(Flags::STREAM) { MessageType::Stream - } else if self.inner.payload.is_none() { + } else if self.inner.payload.borrow().is_none() { MessageType::None } else { MessageType::Payload @@ -103,7 +102,7 @@ impl ClientCodec { impl ClientPayloadCodec { /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { - self.inner.ctype == ConnectionType::KeepAlive + self.inner.ctype.get() == ConnectionType::KeepAlive } /// Transform payload codec to a message codec @@ -116,30 +115,37 @@ impl Decoder for ClientCodec { type Item = ResponseHead; type Error = ParseError; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set"); + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { + debug_assert!( + !self.inner.payload.borrow().is_some(), + "Payload decoder is set" + ); if let Some((req, payload)) = self.inner.decoder.decode(src)? { if let Some(ctype) = req.ctype() { // do not use peer's keep-alive - self.inner.ctype = if ctype == ConnectionType::KeepAlive { - self.inner.ctype - } else { - ctype + if ctype != ConnectionType::KeepAlive { + self.inner.ctype.set(ctype); }; } - if !self.inner.flags.contains(Flags::HEAD) { + if !self.inner.flags.get().contains(Flags::HEAD) { match payload { - PayloadType::None => self.inner.payload = None, - PayloadType::Payload(pl) => self.inner.payload = Some(pl), + PayloadType::None => { + self.inner.payload.borrow_mut().take(); + } + PayloadType::Payload(pl) => { + *self.inner.payload.borrow_mut() = Some(pl) + } PayloadType::Stream(pl) => { - self.inner.payload = Some(pl); - self.inner.flags.insert(Flags::STREAM); + *self.inner.payload.borrow_mut() = Some(pl); + let mut flags = self.inner.flags.get(); + flags.insert(Flags::STREAM); + self.inner.flags.set(flags); } } } else { - self.inner.payload = None; + self.inner.payload.borrow_mut().take(); } reserve_readbuf(src); Ok(Some(req)) @@ -153,19 +159,27 @@ impl Decoder for ClientPayloadCodec { type Item = Option; type Error = PayloadError; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { debug_assert!( - self.inner.payload.is_some(), + self.inner.payload.borrow().is_some(), "Payload decoder is not specified" ); - Ok(match self.inner.payload.as_mut().unwrap().decode(src)? { + let item = self + .inner + .payload + .borrow_mut() + .as_mut() + .unwrap() + .decode(src)?; + + Ok(match item { Some(PayloadItem::Chunk(chunk)) => { reserve_readbuf(src); Some(Some(chunk)) } Some(PayloadItem::Eof) => { - self.inner.payload.take(); + self.inner.payload.borrow_mut().take(); Some(None) } None => None, @@ -177,23 +191,19 @@ impl Encoder for ClientCodec { type Item = Message<(RequestHeadType, BodySize)>; type Error = io::Error; - fn encode( - &mut self, - item: Self::Item, - dst: &mut BytesMut, - ) -> Result<(), Self::Error> { + fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { Message::Item((mut head, length)) => { - let inner = &mut self.inner; - inner.version = head.as_ref().version; - inner - .flags - .set(Flags::HEAD, head.as_ref().method == Method::HEAD); + let inner = &self.inner; + inner.version.set(head.as_ref().version); + let mut flags = inner.flags.get(); + flags.set(Flags::HEAD, head.as_ref().method == Method::HEAD); + inner.flags.set(flags); // connection status - inner.ctype = match head.as_ref().connection_type() { + inner.ctype.set(match head.as_ref().connection_type() { ConnectionType::KeepAlive => { - if inner.flags.contains(Flags::KEEPALIVE_ENABLED) { + if inner.flags.get().contains(Flags::KEEPALIVE_ENABLED) { ConnectionType::KeepAlive } else { ConnectionType::Close @@ -201,16 +211,16 @@ impl Encoder for ClientCodec { } ConnectionType::Upgrade => ConnectionType::Upgrade, ConnectionType::Close => ConnectionType::Close, - }; + }); inner.encoder.encode( dst, &mut head, false, false, - inner.version, + inner.version.get(), length, - inner.ctype, + inner.ctype.get(), &inner.timer, )?; } diff --git a/ntex/src/http/h1/codec.rs b/ntex/src/http/h1/codec.rs index 10c57610..88a5f06a 100644 --- a/ntex/src/http/h1/codec.rs +++ b/ntex/src/http/h1/codec.rs @@ -1,4 +1,4 @@ -use std::{fmt, io}; +use std::{cell::Cell, fmt, io}; use bitflags::bitflags; use bytes::BytesMut; @@ -12,15 +12,13 @@ use crate::http::message::ConnectionType; use crate::http::request::Request; use crate::http::response::Response; -use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; -use super::{decoder, encoder}; -use super::{Message, MessageType}; +use super::{decoder, decoder::PayloadType, encoder, Message}; bitflags! { struct Flags: u8 { const HEAD = 0b0000_0001; - const KEEPALIVE_ENABLED = 0b0000_0010; - const STREAM = 0b0000_0100; + const STREAM = 0b0000_0010; + const KEEPALIVE_ENABLED = 0b0000_0100; } } @@ -28,12 +26,11 @@ bitflags! { pub struct Codec { timer: DateService, decoder: decoder::MessageDecoder, - payload: Option, - version: Version, - ctype: ConnectionType, + version: Cell, + ctype: Cell, // encoder part - flags: Flags, + flags: Cell, encoder: encoder::MessageEncoder>, } @@ -43,6 +40,19 @@ impl Default for Codec { } } +impl Clone for Codec { + fn clone(&self) -> Self { + Codec { + timer: self.timer.clone(), + decoder: self.decoder.clone(), + version: self.version.clone(), + ctype: self.ctype.clone(), + flags: self.flags.clone(), + encoder: self.encoder.clone(), + } + } +} + impl fmt::Debug for Codec { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "h1::Codec({:?})", self.flags) @@ -61,12 +71,11 @@ impl Codec { }; Codec { - flags, timer, + flags: Cell::new(flags), decoder: decoder::MessageDecoder::default(), - payload: None, - version: Version::HTTP_11, - ctype: ConnectionType::Close, + version: Cell::new(Version::HTTP_11), + ctype: Cell::new(ConnectionType::Close), encoder: encoder::MessageEncoder::default(), } } @@ -74,31 +83,19 @@ impl Codec { #[inline] /// Check if request is upgrade pub fn upgrade(&self) -> bool { - self.ctype == ConnectionType::Upgrade + self.ctype.get() == ConnectionType::Upgrade } #[inline] /// Check if last response is keep-alive pub fn keepalive(&self) -> bool { - self.ctype == ConnectionType::KeepAlive + self.ctype.get() == ConnectionType::KeepAlive } #[inline] /// Check if keep-alive enabled on server level pub fn keepalive_enabled(&self) -> bool { - self.flags.contains(Flags::KEEPALIVE_ENABLED) - } - - #[inline] - /// Check last request's message type - pub fn message_type(&self) -> MessageType { - if self.flags.contains(Flags::STREAM) { - MessageType::Stream - } else if self.payload.is_none() { - MessageType::None - } else { - MessageType::Payload - } + self.flags.get().contains(Flags::KEEPALIVE_ENABLED) } #[inline] @@ -106,41 +103,36 @@ impl Codec { pub fn set_date_header(&self, dst: &mut BytesMut) { self.timer.set_date_header(dst) } + + fn insert_flags(&self, f: Flags) { + let mut flags = self.flags.get(); + flags.insert(f); + self.flags.set(flags); + } } impl Decoder for Codec { - type Item = Message; + type Item = (Request, PayloadType); type Error = ParseError; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - if let Some(ref mut payload) = self.payload { - Ok(match payload.decode(src)? { - Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))), - Some(PayloadItem::Eof) => { - self.payload.take(); - Some(Message::Chunk(None)) - } - None => None, - }) - } else if let Some((req, payload)) = self.decoder.decode(src)? { + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { + if let Some((req, payload)) = self.decoder.decode(src)? { let head = req.head(); - self.flags.set(Flags::HEAD, head.method == Method::HEAD); - self.version = head.version; - self.ctype = head.connection_type(); - if self.ctype == ConnectionType::KeepAlive - && !self.flags.contains(Flags::KEEPALIVE_ENABLED) + let mut flags = self.flags.get(); + flags.set(Flags::HEAD, head.method == Method::HEAD); + self.flags.set(flags); + self.version.set(head.version); + self.ctype.set(head.connection_type()); + if self.ctype.get() == ConnectionType::KeepAlive + && !flags.contains(Flags::KEEPALIVE_ENABLED) { - self.ctype = ConnectionType::Close + self.ctype.set(ConnectionType::Close) } - match payload { - PayloadType::None => self.payload = None, - PayloadType::Payload(pl) => self.payload = Some(pl), - PayloadType::Stream(pl) => { - self.payload = Some(pl); - self.flags.insert(Flags::STREAM); - } + + if let PayloadType::Stream(_) = payload { + self.insert_flags(Flags::STREAM) } - Ok(Some(Message::Item(req))) + Ok(Some((req, payload))) } else { Ok(None) } @@ -151,36 +143,28 @@ impl Encoder for Codec { type Item = Message<(Response<()>, BodySize)>; type Error = io::Error; - fn encode( - &mut self, - item: Self::Item, - dst: &mut BytesMut, - ) -> Result<(), Self::Error> { + fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { Message::Item((mut res, length)) => { // set response version - res.head_mut().version = self.version; + res.head_mut().version = self.version.get(); // connection status - self.ctype = if let Some(ct) = res.head().ctype() { - if ct == ConnectionType::KeepAlive { - self.ctype - } else { - ct + if let Some(ct) = res.head().ctype() { + if ct != ConnectionType::KeepAlive { + self.ctype.set(ct) } - } else { - self.ctype - }; + } // encode message self.encoder.encode( dst, &mut res, - self.flags.contains(Flags::HEAD), - self.flags.contains(Flags::STREAM), - self.version, + self.flags.get().contains(Flags::HEAD), + self.flags.get().contains(Flags::STREAM), + self.version.get(), length, - self.ctype, + self.ctype.get(), &self.timer, )?; // self.headers_size = (dst.len() - len) as u32; @@ -198,22 +182,25 @@ impl Encoder for Codec { #[cfg(test)] mod tests { - use bytes::BytesMut; + use bytes::{Bytes, BytesMut}; use super::*; - use crate::http::{HttpMessage, Method}; + use crate::http::{h1::PayloadItem, HttpMessage, Method}; #[test] fn test_http_request_chunked_payload_and_next_message() { - let mut codec = Codec::default(); + let codec = Codec::default(); assert!(format!("{:?}", codec).contains("h1::Codec")); let mut buf = BytesMut::from( "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let item = codec.decode(&mut buf).unwrap().unwrap(); - let req = item.message(); + let (req, pl) = codec.decode(&mut buf).unwrap().unwrap(); + let pl = match pl { + PayloadType::Payload(pl) => pl, + _ => panic!(), + }; assert_eq!(req.method(), Method::GET); assert!(req.chunked().unwrap()); @@ -225,22 +212,21 @@ mod tests { .iter(), ); - let msg = codec.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"data"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg, PayloadItem::Chunk(Bytes::from_static(b"data"))); - let msg = codec.decode(&mut buf).unwrap().unwrap(); - assert_eq!(msg.chunk().as_ref(), b"line"); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg, PayloadItem::Chunk(Bytes::from_static(b"line"))); - let msg = codec.decode(&mut buf).unwrap().unwrap(); - assert!(msg.eof()); + let msg = pl.decode(&mut buf).unwrap().unwrap(); + assert_eq!(msg, PayloadItem::Eof); // decode next message - let item = codec.decode(&mut buf).unwrap().unwrap(); - let req = item.message(); + let (req, _pl) = codec.decode(&mut buf).unwrap().unwrap(); assert_eq!(*req.method(), Method::POST); assert!(req.chunked().unwrap()); - let mut codec = Codec::default(); + let codec = Codec::default(); let mut buf = BytesMut::from( "GET /test HTTP/1.1\r\n\ connection: upgrade\r\n\r\n", diff --git a/ntex/src/http/h1/decoder.rs b/ntex/src/http/h1/decoder.rs index 6246858f..e9323ded 100644 --- a/ntex/src/http/h1/decoder.rs +++ b/ntex/src/http/h1/decoder.rs @@ -1,12 +1,10 @@ -use std::convert::TryFrom; -use std::marker::PhantomData; -use std::mem::MaybeUninit; -use std::task::Poll; +use std::{ + cell::Cell, convert::TryFrom, marker::PhantomData, mem::MaybeUninit, task::Poll, +}; use bytes::{Buf, Bytes, BytesMut}; use http::header::{HeaderName, HeaderValue}; use http::{header, Method, StatusCode, Uri, Version}; -use log::{debug, error, trace}; use crate::codec::Decoder; use crate::http::error::ParseError; @@ -23,7 +21,7 @@ pub(super) struct MessageDecoder(PhantomData); #[derive(Debug)] /// Incoming request type -pub(super) enum PayloadType { +pub enum PayloadType { None, Payload(PayloadDecoder), Stream(PayloadDecoder), @@ -35,11 +33,17 @@ impl Default for MessageDecoder { } } +impl Clone for MessageDecoder { + fn clone(&self) -> Self { + MessageDecoder(PhantomData) + } +} + impl Decoder for MessageDecoder { type Item = (T, PayloadType); type Error = ParseError; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { T::decode(src) } } @@ -91,11 +95,11 @@ pub(super) trait MessageType: Sized { content_length = Some(len); } } else { - debug!("illegal Content-Length: {:?}", s); + log::debug!("illegal Content-Length: {:?}", s); return Err(ParseError::Header); } } else { - debug!("illegal Content-Length: {:?}", value); + log::debug!("illegal Content-Length: {:?}", value); return Err(ParseError::Header); } } @@ -290,7 +294,7 @@ impl MessageType for ResponseHead { } httparse::Status::Partial => { return if src.len() >= MAX_BUFFER_SIZE { - error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); + log::error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); Err(ParseError::TooLarge) } else { Ok(None) @@ -351,7 +355,7 @@ impl HeaderIndex { #[derive(Debug, Clone, PartialEq)] /// Http payload item -pub(super) enum PayloadItem { +pub enum PayloadItem { Chunk(Bytes), Eof, } @@ -361,29 +365,31 @@ pub(super) enum PayloadItem { /// If a message body does not include a Transfer-Encoding, it *should* /// include a Content-Length header. #[derive(Debug, Clone, PartialEq)] -pub(super) struct PayloadDecoder { - kind: Kind, +pub struct PayloadDecoder { + kind: Cell, } impl PayloadDecoder { pub(super) fn length(x: u64) -> PayloadDecoder { PayloadDecoder { - kind: Kind::Length(x), + kind: Cell::new(Kind::Length(x)), } } pub(super) fn chunked() -> PayloadDecoder { PayloadDecoder { - kind: Kind::Chunked(ChunkedState::Size, 0), + kind: Cell::new(Kind::Chunked(ChunkedState::Size, 0)), } } pub(super) fn eof() -> PayloadDecoder { - PayloadDecoder { kind: Kind::Eof } + PayloadDecoder { + kind: Cell::new(Kind::Eof), + } } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq)] enum Kind { /// A Reader used when a Content-Length header is passed with a positive /// integer. @@ -407,7 +413,7 @@ enum Kind { Eof, } -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Copy, Clone)] enum ChunkedState { Size, SizeLws, @@ -425,8 +431,10 @@ impl Decoder for PayloadDecoder { type Item = PayloadItem; type Error = ParseError; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - match self.kind { + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { + let mut kind = self.kind.get(); + + match kind { Kind::Length(ref mut remaining) => { if *remaining == 0 { Ok(Some(PayloadItem::Eof)) @@ -443,30 +451,35 @@ impl Decoder for PayloadDecoder { buf = src.split_to(*remaining as usize).freeze(); *remaining = 0; }; - trace!("Length read: {}", buf.len()); + self.kind.set(kind); + log::trace!("Length read: {}", buf.len()); Ok(Some(PayloadItem::Chunk(buf))) } } Kind::Chunked(ref mut state, ref mut size) => { - loop { + let result = loop { let mut buf = None; // advances the chunked state *state = match state.step(src, size, &mut buf) { - Poll::Pending => return Ok(None), + Poll::Pending => break Ok(None), Poll::Ready(Ok(state)) => state, - Poll::Ready(Err(e)) => return Err(e), + Poll::Ready(Err(e)) => break Err(e), }; + if *state == ChunkedState::End { - trace!("End of chunked stream"); - return Ok(Some(PayloadItem::Eof)); + log::trace!("End of chunked stream"); + break Ok(Some(PayloadItem::Eof)); } + if let Some(buf) = buf { - return Ok(Some(PayloadItem::Chunk(buf))); + break Ok(Some(PayloadItem::Chunk(buf))); } if src.is_empty() { - return Ok(None); + break Ok(None); } - } + }; + self.kind.set(kind); + result } Kind::Eof => { if src.is_empty() { @@ -544,7 +557,7 @@ impl ChunkedState { } fn read_size_lws(rdr: &mut BytesMut) -> Poll> { - trace!("read_size_lws"); + log::trace!("read_size_lws"); match byte!(rdr) { // LWS can follow the chunk size, but no more digits can come b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), @@ -577,7 +590,7 @@ impl ChunkedState { rem: &mut u64, buf: &mut Option, ) -> Poll> { - trace!("Chunked read, remaining={:?}", rem); + log::trace!("Chunked read, remaining={:?}", rem); let len = rdr.len() as u64; if len == 0 { @@ -693,7 +706,7 @@ mod tests { fn test_parse() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n"); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); match reader.decode(&mut buf) { Ok(Some((req, _))) => { assert_eq!(req.version(), Version::HTTP_11); @@ -708,7 +721,7 @@ mod tests { fn test_parse_partial() { let mut buf = BytesMut::from("PUT /test HTTP/1"); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b".1\r\n\r\n"); @@ -722,7 +735,7 @@ mod tests { fn test_parse_post() { let mut buf = BytesMut::from("POST /test2 HTTP/1.0\r\n\r\n"); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(req.version(), Version::HTTP_10); assert_eq!(*req.method(), Method::POST); @@ -734,9 +747,9 @@ mod tests { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); + let pl = pl.unwrap(); assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -751,9 +764,9 @@ mod tests { let mut buf = BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); + let pl = pl.unwrap(); assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -766,7 +779,7 @@ mod tests { #[test] fn test_parse_partial_eof() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); assert!(reader.decode(&mut buf).unwrap().is_none()); buf.extend(b"\r\n"); @@ -780,7 +793,7 @@ mod tests { fn test_headers_split_field() { let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n"); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); assert! { reader.decode(&mut buf).unwrap().is_none() } buf.extend(b"t"); @@ -810,7 +823,7 @@ mod tests { Set-Cookie: c1=cookie1\r\n\ Set-Cookie: c2=cookie2\r\n\r\n", ); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, _) = reader.decode(&mut buf).unwrap().unwrap(); let val: Vec<_> = req @@ -1037,7 +1050,7 @@ mod tests { upgrade: websocket\r\n\r\n\ some raw data", ); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); assert_eq!(req.head().connection_type(), ConnectionType::Upgrade); assert!(req.upgrade()); @@ -1086,9 +1099,9 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); + let pl = pl.unwrap(); assert!(req.chunked().unwrap()); buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); @@ -1109,9 +1122,9 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n", ); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); + let pl = pl.unwrap(); assert!(req.chunked().unwrap()); buf.extend( @@ -1140,9 +1153,9 @@ mod tests { transfer-encoding: chunked\r\n\r\n", ); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (req, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); + let pl = pl.unwrap(); assert!(req.chunked().unwrap()); buf.extend(b"4\r\n1111\r\n"); @@ -1185,9 +1198,9 @@ mod tests { transfer-encoding: chunked\r\n\r\n"[..], ); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); + let pl = pl.unwrap(); assert!(msg.chunked().unwrap()); buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") @@ -1203,9 +1216,9 @@ mod tests { fn test_response_http10_read_until_eof() { let mut buf = BytesMut::from(&"HTTP/1.0 200 Ok\r\n\r\ntest data"[..]); - let mut reader = MessageDecoder::::default(); + let reader = MessageDecoder::::default(); let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap(); - let mut pl = pl.unwrap(); + let pl = pl.unwrap(); let chunk = pl.decode(&mut buf).unwrap().unwrap(); assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"test data"))); diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 6ced0d4f..e6ea85b0 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,1297 +1,606 @@ -use std::{fmt, io, mem, net, pin::Pin, rc::Rc, task::Context, task::Poll}; +//! Framed transport dispatcher +use std::error::Error; +use std::task::{Context, Poll}; +use std::{ + cell::RefCell, fmt, future::Future, marker::PhantomData, net, pin::Pin, rc::Rc, + time::Duration, time::Instant, +}; -use bitflags::bitflags; -use bytes::{Buf, BytesMut}; -use futures::{ready, Future}; +use bytes::Bytes; -use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; -use crate::http::body::{Body, BodySize, MessageBody, ResponseBody}; +use crate::codec::{AsyncRead, AsyncWrite, Decoder}; +use crate::framed::{ReadTask, State as IoState, WriteTask}; +use crate::service::Service; + +use crate::http; +use crate::http::body::{BodySize, MessageBody, ResponseBody}; use crate::http::config::DispatcherConfig; -use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError}; +use crate::http::error::{DispatchError, PayloadError, ResponseError}; use crate::http::helpers::DataFactory; use crate::http::request::Request; use crate::http::response::Response; -use crate::rt::time::{delay_until, Delay, Instant}; -use crate::Service; -use super::codec::Codec; +use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; use super::payload::{Payload, PayloadSender, PayloadStatus}; -use super::{Message, MessageType}; +use super::{codec::Codec, Message}; -const READ_LW_BUFFER_SIZE: usize = 1024; -const READ_HW_BUFFER_SIZE: usize = 4096; -const WRITE_LW_BUFFER_SIZE: usize = 2048; -const WRITE_HW_BUFFER_SIZE: usize = 8192; -const BUFFER_SIZE: usize = 32_768; - -bitflags! { +bitflags::bitflags! { pub struct Flags: u16 { /// We parsed one complete request message - const STARTED = 0b0000_0001; + const STARTED = 0b0000_0001; /// Keep-alive is enabled on current connection - const KEEPALIVE = 0b0000_0010; - /// Socket is disconnected, read or write side - const DISCONNECT = 0b0000_0100; - /// Connection is upgraded or request parse error (bad request) - const STOP_READING = 0b0000_1000; - /// Shutdown is in process (flushing and io shutdown timer) - const SHUTDOWN = 0b0001_0000; - /// Io shutdown process started - const SHUTDOWN_IO = 0b0010_0000; - /// Shutdown timer is started - const SHUTDOWN_TM = 0b0100_0000; - /// Connection is upgraded - const UPGRADE = 0b1000_0000; - /// All data has been read - const READ_EOF = 0b0001_0000_0000; - /// Keep alive is enabled - const HAS_KEEPALIVE = 0b0010_0000_0000; + const KEEPALIVE = 0b0000_0010; + /// Upgrade request + const UPGRADE = 0b0000_0100; } } pin_project_lite::pin_project! { -/// Dispatcher for HTTP/1.1 protocol -pub struct Dispatcher -where - S: Service, - S::Error: ResponseError, - B: MessageBody, - X: Service, - X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display, -{ - #[pin] - call: CallState, - inner: InnerDispatcher, - #[pin] - upgrade: Option, -} + /// Dispatcher for HTTP/1.1 protocol + pub struct Dispatcher { + #[pin] + call: CallState, + st: State, + inner: DispatcherInner, + } } -struct InnerDispatcher -where - S: Service, - S::Error: ResponseError, - B: MessageBody, - X: Service, - X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display, -{ - config: Rc>, - on_connect: Option>, - peer_addr: Option, - flags: Flags, - error: Option, - - res_payload: Option>, - req_payload: Option, - - ka_expire: Instant, - ka_timer: Option, - - io: Option, - read_buf: BytesMut, - write_buf: BytesMut, - codec: Codec, -} - -enum DispatcherMessage { - Request(Request), - Upgrade(Request), - Error(Response<()>), -} - -#[derive(Clone, Copy, PartialEq, Eq)] -enum PollWrite { - /// allowed to process next request - AllowNext, - /// write buffer is full - Pending, - /// waiting for response stream (app response) - /// or write buffer is full - PendingResponse, +enum State { + Call, + ReadRequest, + ReadPayload, + SendPayload { body: ResponseBody }, + Stop, } pin_project_lite::pin_project! { #[project = CallStateProject] - enum CallState { - Io, - Expect { #[pin] fut: X::Future }, + enum CallState { + None, Service { #[pin] fut: S::Future }, + Expect { #[pin] fut: X::Future }, + Upgrade { #[pin] fut: U::Future }, } } -enum CallProcess { - /// next call is available - Next(CallState), - /// waiting for service call response completion - Pending, - /// call queue is empty - Io, - /// Upgrade connection - Upgrade(U::Future), +struct DispatcherInner { + flags: Flags, + codec: Codec, + config: Rc>, + state: IoState, + expire: Instant, + error: Option, + payload: Option<(PayloadDecoder, PayloadSender)>, + peer_addr: Option, + on_connect_data: Option>, + _t: PhantomData<(S, B)>, } -impl Dispatcher +#[derive(Copy, Clone, PartialEq, Eq)] +enum PollPayloadStatus { + Done, + Updated, + Pending, + Dropped, +} + +impl Dispatcher where - T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display, + U: Service, + U::Error: Error + fmt::Display, { - /// Create http/1 dispatcher. - pub(in crate::http) fn new( - config: Rc>, - stream: T, - peer_addr: Option, - on_connect: Option>, - ) -> Self { - let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled()); - // slow request timer - let timeout = config.client_timer(); - - Dispatcher::with_timeout( - config, - stream, - codec, - BytesMut::with_capacity(READ_HW_BUFFER_SIZE), - timeout, - peer_addr, - on_connect, - ) - } - - /// Create http/1 dispatcher with slow request timeout. - pub(in crate::http) fn with_timeout( - config: Rc>, + /// Construct new `Dispatcher` instance with outgoing messages stream. + pub(in crate::http) fn new( io: T, - codec: Codec, - read_buf: BytesMut, - timeout: Option, + config: Rc>, peer_addr: Option, - on_connect: Option>, - ) -> Self { - let keepalive = config.keep_alive_enabled(); - let mut flags = if keepalive { - Flags::KEEPALIVE | Flags::READ_EOF - } else { - Flags::READ_EOF - }; - if config.keep_alive_timer_enabled() { - flags |= Flags::HAS_KEEPALIVE; + on_connect_data: Option>, + ) -> Self + where + T: AsyncRead + AsyncWrite + Unpin + 'static, + { + let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled()); + + let state = IoState::new(); + state.set_disconnect_timeout(config.client_disconnect as u16); + + let mut expire = config.timer_h1.now(); + let io = Rc::new(RefCell::new(io)); + + // slow-request timer + if config.client_timeout != 0 { + expire += Duration::from_secs(config.client_timeout); + config.timer_h1.register(expire, expire, &state); } - // keep-alive timer - let (ka_expire, ka_timer) = if let Some(delay) = timeout { - (delay.deadline(), Some(delay)) - } else if let Some(delay) = config.keep_alive_timer() { - (delay.deadline(), Some(delay)) - } else { - (config.now(), None) - }; + // start support io tasks + crate::rt::spawn(ReadTask::new(io.clone(), state.clone())); + crate::rt::spawn(WriteTask::new(io, state.clone())); Dispatcher { - call: CallState::Io, - upgrade: None, - inner: InnerDispatcher { - write_buf: BytesMut::with_capacity(WRITE_HW_BUFFER_SIZE), - req_payload: None, - res_payload: None, + call: CallState::None, + st: State::ReadRequest, + inner: DispatcherInner { + flags: Flags::empty(), error: None, - io: Some(io), - config, + payload: None, codec, - read_buf, - flags, + config, + state, + expire, peer_addr, - on_connect, - ka_expire, - ka_timer, + on_connect_data, + _t: PhantomData, }, } } } -impl Future for Dispatcher +impl Future for Dispatcher where - T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, - X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display, + X::Error: ResponseError + 'static, + U: Service, + U::Error: Error + fmt::Display + 'static, { type Output = Result<(), DispatchError>; - #[allow(clippy::cognitive_complexity)] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.as_mut().project(); - // handle upgrade request - if this.inner.flags.contains(Flags::UPGRADE) { - return this.upgrade.as_pin_mut().unwrap().poll(cx).map_err(|e| { - error!("Upgrade handler error: {}", e); - DispatchError::Upgrade - }); - } - - // shutdown process - if this.inner.flags.contains(Flags::SHUTDOWN) { - return this.inner.poll_shutdown(cx); - } - - // process incoming bytes stream - let mut not_completed = !this.inner.poll_read(cx); - this.inner.decode_payload(); - loop { - // process incoming bytes stream, but only if - // previous iteration didnt read whole buffer - if not_completed { - not_completed = !this.inner.poll_read(cx); - } - - let st = match this.call.project() { - CallStateProject::Service { mut fut } => { - loop { - // we have to loop because of read back-pressure, - // check Poll::Pending processing - match fut.poll(cx) { + match this.st { + State::Call => { + let next = match this.call.project() { + // handle SERVICE call + CallStateProject::Service { fut } => { + // we have to loop because of read back-pressure, + // check Poll::Pending processing + match fut.poll(cx) { + Poll::Ready(result) => match result { + Ok(res) => { + let (res, body) = res.into().into_parts(); + *this.st = this.inner.send_response(res, body) + } + Err(e) => { + *this.st = this.inner.handle_error(e, false) + } + }, + Poll::Pending => { + // we might need to read more data into a request payload + // (ie service future can wait for payload data) + if this.inner.poll_read_payload(cx) + != PollPayloadStatus::Updated + { + return Poll::Pending; + } + } + } + None + } + // handle EXPECT call + CallStateProject::Expect { fut } => match fut.poll(cx) { Poll::Ready(result) => match result { - Ok(res) => { - break this.inner.process_response(res.into())? + Ok(req) => { + this.inner.state.with_write_buf(|buf| { + buf.extend_from_slice( + b"HTTP/1.1 100 Continue\r\n\r\n", + ) + }); + Some(if this.inner.flags.contains(Flags::UPGRADE) { + // Handle UPGRADE request + CallState::Upgrade { + fut: this + .inner + .config + .upgrade + .as_ref() + .unwrap() + .call(( + req, + this.inner.state.clone(), + this.inner.codec.clone(), + )), + } + } else { + CallState::Service { + fut: this.inner.config.service.call(req), + } + }) } Err(e) => { - let res: Response = e.into(); - break this.inner.process_response( - res.map_body(|_, body| body.into_body()), - )?; + *this.st = this.inner.handle_error(e, true); + None } }, Poll::Pending => { - // if read back-pressure is enabled, we might need - // to read more data (ie service future can wait for payload data) - if this.inner.req_payload.is_some() && not_completed { - // read more from io stream - not_completed = !this.inner.poll_read(cx); - - // more payload chunks has been decoded - if this.inner.decode_payload() { - // restore consumed future - this = self.as_mut().project(); - fut = { - match this.call.project() { - CallStateProject::Service { fut } => fut, - _ => panic!(), - } - }; - continue; - } - } - break CallProcess::Pending; + // expect service call must resolve before + // we can do any more io processing. + // + // TODO: check keep-alive timer interaction + return Poll::Pending; } + }, + CallStateProject::Upgrade { fut } => { + return fut.poll(cx).map_err(|e| { + error!("Upgrade handler error: {}", e); + DispatchError::Upgrade(Box::new(e)) + }); } - } - } - // handle EXPECT call - CallStateProject::Expect { fut } => match fut.poll(cx) { - Poll::Ready(result) => match result { - Ok(req) => { - this.inner - .write_buf - .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n"); - CallProcess::Next(CallState::Service { - fut: this.inner.config.service.call(req), - }) - } - Err(e) => { - let res: Response = e.into(); - this.inner.process_response( - res.map_body(|_, body| body.into_body()), - )? - } - }, - Poll::Pending => { - // expect service call must resolve before - // we can do any more io processing. - // - // TODO: check keep-alive timer interaction - return Poll::Pending; - } - }, - CallStateProject::Io => CallProcess::Io, - }; - - let idle = match st { - CallProcess::Next(st) => { - // we have next call state, just proceed with it - this = self.as_mut().project(); - this.call.set(st); - continue; - } - CallProcess::Pending => { - // service response is in process, - // we just flush output and that is it - this.inner.poll_write(cx)?; - false - } - CallProcess::Io => { - // service call queue is empty, we can process next request - let write = if !this.inner.flags.contains(Flags::STARTED) { - PollWrite::AllowNext - } else { - this.inner.decode_payload(); - this.inner.poll_write(cx)? + CallStateProject::None => unreachable!(), }; - match write { - PollWrite::AllowNext => { - match this.inner.process_messages(CallProcess::Io)? { - CallProcess::Next(st) => { - this = self.as_mut().project(); - this.call.set(st); + + this = self.as_mut().project(); + if let Some(next) = next { + this.call.set(next); + } + } + State::ReadRequest => { + // stop dispatcher + if this.inner.state.is_dsp_stopped() { + log::trace!("dispatcher is instructed to stop"); + *this.st = State::Stop; + continue; + } + + // keep-alive timeout + if this.inner.state.is_keepalive() { + if !this.inner.flags.contains(Flags::STARTED) { + log::trace!("slow request timeout"); + let (req, body) = + Response::RequestTimeout().finish().into_parts(); + let _ = this.inner.send_response(req, body.into_body()); + this.inner.error = Some(DispatchError::SlowRequestTimeout); + } else { + log::trace!("keep-alive timeout, close connection"); + } + *this.st = State::Stop; + + continue; + } + + // decode incoming bytes stream + if this.inner.state.is_read_ready() { + match this.inner.state.decode_item(&this.inner.codec) { + Ok(Some((mut req, pl))) => { + log::trace!("http message is received: {:?}", req); + req.head_mut().peer_addr = this.inner.peer_addr; + + // configure request payload + let upgrade = match pl { + PayloadType::None => false, + PayloadType::Payload(decoder) => { + let (ps, pl) = Payload::create(false); + req.replace_payload(http::Payload::H1(pl)); + this.inner.payload = Some((decoder, ps)); + false + } + PayloadType::Stream(decoder) => { + if this.inner.config.upgrade.is_none() { + let (ps, pl) = Payload::create(false); + req.replace_payload(http::Payload::H1(pl)); + this.inner.payload = Some((decoder, ps)); + false + } else { + this.inner.flags.insert(Flags::UPGRADE); + true + } + } + }; + + // unregister slow-request timer + if !this.inner.flags.contains(Flags::STARTED) { + this.inner.flags.insert(Flags::STARTED); + this.inner.config.timer_h1.unregister( + this.inner.expire, + &this.inner.state, + ); + } + + // set on_connect data + if let Some(ref on_connect) = this.inner.on_connect_data + { + on_connect.set(&mut req.extensions_mut()); + } + + // call service + *this.st = State::Call; + this.call.set(if req.head().expect() { + // Handle `EXPECT: 100-Continue` header + CallState::Expect { + fut: this.inner.config.expect.call(req), + } + } else if upgrade { + log::trace!("initate upgrade handling"); + // Handle UPGRADE request + CallState::Upgrade { + fut: this + .inner + .config + .upgrade + .as_ref() + .unwrap() + .call(( + req, + this.inner.state.clone(), + this.inner.codec.clone(), + )), + } + } else { + // Handle normal requests + CallState::Service { + fut: this.inner.config.service.call(req), + } + }); + } + Ok(None) => { + // if connection is not keep-alive then disconnect + if this.inner.flags.contains(Flags::STARTED) + && !this.inner.flags.contains(Flags::KEEPALIVE) + { + *this.st = State::Stop; continue; } - CallProcess::Upgrade(fut) => { - this.upgrade.set(Some(fut)); - return self.poll(cx); - } - CallProcess::Io => true, - CallProcess::Pending => unreachable!(), + this.inner.state.dsp_read_more_data(cx.waker()); + return Poll::Pending; + } + Err(err) => { + // Malformed requests, respond with 400 + let (res, body) = + Response::BadRequest().finish().into_parts(); + this.inner.error = Some(DispatchError::Parse(err)); + *this.st = + this.inner.send_response(res, body.into_body()); + continue; } } - PollWrite::Pending => this.inner.res_payload.is_none(), - PollWrite::PendingResponse => { - this.inner.flags.contains(Flags::DISCONNECT) + } else { + // if connection is not keep-alive then disconnect + if this.inner.flags.contains(Flags::STARTED) + && !this.inner.flags.contains(Flags::KEEPALIVE) + { + *this.st = State::Stop; + continue; } + this.inner.state.dsp_register_task(cx.waker()); + return Poll::Pending; } } - CallProcess::Upgrade(fut) => { - this.upgrade.set(Some(fut)); - return self.poll(cx); - } - }; + // consume request's payload + State::ReadPayload => loop { + match this.inner.poll_read_payload(cx) { + PollPayloadStatus::Updated => continue, + PollPayloadStatus::Pending => return Poll::Pending, + PollPayloadStatus::Done => { + *this.st = { + this.inner.reset_keepalive(); + State::ReadRequest + } + } + PollPayloadStatus::Dropped => *this.st = State::Stop, + } + break; + }, + // send response body + State::SendPayload { ref mut body } => { + this.inner.poll_read_payload(cx); - // socket is closed and we are not processing any service responses - if this - .inner - .flags - .intersects(Flags::DISCONNECT | Flags::STOP_READING) - && idle - { - trace!("Shutdown connection (no more work) {:?}", this.inner.flags); - this.inner.flags.insert(Flags::SHUTDOWN); - } - // we dont have any parsed requests and output buffer is flushed - else if idle && this.inner.write_buf.is_empty() { - if let Some(err) = this.inner.error.take() { - trace!("Dispatcher error {:?}", err); - return Poll::Ready(Err(err)); + match body.poll_next_chunk(cx) { + Poll::Ready(item) => { + if let Some(st) = this.inner.send_payload(item) { + *this.st = st; + } + } + Poll::Pending => return Poll::Pending, + } } + // prepare to shutdown + State::Stop => { + this.inner.state.shutdown_io(); + this.inner.unregister_keepalive(); - // disconnect if keep-alive is not enabled - if this.inner.flags.contains(Flags::STARTED) - && !this.inner.flags.contains(Flags::KEEPALIVE) - { - trace!("Shutdown, keep-alive is not enabled"); - this.inner.flags.insert(Flags::SHUTDOWN); + // get io error + if this.inner.error.is_none() { + this.inner.error = + this.inner.state.take_io_error().map(DispatchError::Io); + } + + return Poll::Ready(if let Some(err) = this.inner.error.take() { + Err(err) + } else { + Ok(()) + }); } } - - // disconnect if shutdown - return if this.inner.flags.contains(Flags::SHUTDOWN) { - this.inner.poll_shutdown(cx) - } else { - if this.inner.poll_flush(cx)? { - // some data has been written to io stream - this = self.as_mut().project(); - continue; - } - - // keep-alive book-keeping - if this.inner.ka_timer.is_some() && this.inner.poll_keepalive(cx, idle) { - this.inner.poll_shutdown(cx) - } else { - Poll::Pending - } - }; } } } -impl InnerDispatcher +impl DispatcherInner where - T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, - X: Service, - X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display, { - /// shutdown process - fn poll_shutdown( - &mut self, - cx: &mut Context<'_>, - ) -> Poll> { - // we can not do anything here - if self.flags.contains(Flags::DISCONNECT) { - return Poll::Ready(Ok(())); - } - - if !self.flags.contains(Flags::SHUTDOWN_IO) { - self.poll_flush(cx)?; - - if self.write_buf.is_empty() { - ready!(Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)?); - self.flags.insert(Flags::SHUTDOWN_IO); - } - } - - // read until 0 or err - let mut buf = [0u8; 512]; - while let Poll::Ready(res) = - Pin::new(self.io.as_mut().unwrap()).poll_read(cx, &mut buf) - { - match res { - Err(_) | Ok(0) => return Poll::Ready(Ok(())), - _ => (), - } - } - - // shutdown timeout - if self.ka_timer.is_none() { - if self.flags.contains(Flags::SHUTDOWN_TM) { - // shutdown timeout is not enabled - Poll::Pending - } else { - self.flags.insert(Flags::SHUTDOWN_TM); - if let Some(interval) = self.config.client_disconnect_timer() { - trace!("Start shutdown timer for {:?}", interval); - self.ka_timer = Some(delay_until(interval)); - let _ = Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx); - } - Poll::Pending - } - } else { - let mut timer = self.ka_timer.as_mut().unwrap(); - - // configure timer - if !self.flags.contains(Flags::SHUTDOWN_TM) { - if let Some(interval) = self.config.client_disconnect_timer() { - self.flags.insert(Flags::SHUTDOWN_TM); - timer.reset(interval); - } else { - let _ = self.ka_timer.take(); - return Poll::Pending; - } - } - - match Pin::new(&mut timer).poll(cx) { - Poll::Ready(_) => { - // if we get timeout during shutdown, drop connection - Poll::Ready(Err(DispatchError::DisconnectTimeout)) - } - _ => Poll::Pending, - } + fn unregister_keepalive(&mut self) { + if self.flags.contains(Flags::KEEPALIVE) { + self.config.timer_h1.unregister(self.expire, &self.state); } } - /// Flush stream - fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result { - let len = self.write_buf.len(); - if len == 0 { - return Ok(false); + fn reset_keepalive(&mut self) { + // re-register keep-alive + if self.flags.contains(Flags::KEEPALIVE) { + let expire = + self.config.timer_h1.now() + Duration::from_secs(self.config.keep_alive); + self.config + .timer_h1 + .register(expire, self.expire, &self.state); + self.expire = expire; + self.state.reset_keepalive(); } - - let mut written = 0; - let mut io = self.io.as_mut().unwrap(); - - while written < len { - match Pin::new(&mut io).poll_write(cx, &self.write_buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - trace!("Disconnected during flush, written {}", written); - return Err(DispatchError::Io(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))); - } else { - written += n - } - } - Poll::Pending => break, - Poll::Ready(Err(e)) => { - trace!("Error during flush: {}", e); - return Err(DispatchError::Io(e)); - } - } - } - if written == len { - // flushed whole buffer, we dont need to reallocate - unsafe { self.write_buf.set_len(0) } - } else { - self.write_buf.advance(written); - } - Ok(written != 0) } - fn send_response( - &mut self, - msg: Response<()>, - body: ResponseBody, - ) -> Result { + fn handle_error(&mut self, err: E, critical: bool) -> State + where + E: ResponseError + 'static, + { + let res: Response = (&err).into(); + let (res, body) = res.into_parts(); + let state = self.send_response(res, body.into_body()); + + // check if we can continue after error + if critical || self.payload.take().is_some() { + self.error = Some(DispatchError::Service(Box::new(err))); + State::Stop + } else { + state + } + } + + fn send_response(&mut self, msg: Response<()>, body: ResponseBody) -> State { trace!("Sending response: {:?} body: {:?}", msg, body.size()); // we dont need to process responses if socket is disconnected // but we still want to handle requests with app service - // so we skip response processing for disconnected connection - if !self.flags.contains(Flags::DISCONNECT) { - self.codec - .encode(Message::Item((msg, body.size())), &mut self.write_buf) + // so we skip response processing for droppped connection + if !self.state.is_io_err() { + let result = self + .state + .write_item(Message::Item((msg, body.size())), &self.codec) .map_err(|err| { - if let Some(mut payload) = self.req_payload.take() { - payload.set_error(PayloadError::Incomplete(None)); + if let Some(mut payload) = self.payload.take() { + payload.1.set_error(PayloadError::Incomplete(None)); } - DispatchError::Io(err) - })?; + err + }); - self.flags.set(Flags::KEEPALIVE, self.codec.keepalive()); - - match body.size() { - BodySize::None | BodySize::Empty => { - // update keep-alive timer - if self.flags.contains(Flags::HAS_KEEPALIVE) { - if let Some(expire) = self.config.keep_alive_expire() { - self.ka_expire = expire; - } - } - Ok(true) - } - _ => { - self.res_payload = Some(body); - Ok(false) - } - } - } else { - Ok(false) - } - } - - fn poll_write(&mut self, cx: &mut Context<'_>) -> Result { - while let Some(ref mut stream) = self.res_payload { - let len = self.write_buf.len(); - - if len < BUFFER_SIZE { - // increase write buffer - let remaining = self.write_buf.capacity() - len; - if remaining < WRITE_LW_BUFFER_SIZE { - self.write_buf.reserve(BUFFER_SIZE - remaining); - } - - match stream.poll_next_chunk(cx) { - Poll::Ready(Some(Ok(item))) => { - trace!("Got response chunk: {:?}", item.len()); - self.codec - .encode(Message::Chunk(Some(item)), &mut self.write_buf)?; - } - Poll::Ready(None) => { - trace!("Response payload eof"); - self.codec - .encode(Message::Chunk(None), &mut self.write_buf)?; - self.res_payload = None; - - // update keep-alive timer - if self.flags.contains(Flags::HAS_KEEPALIVE) { - if let Some(expire) = self.config.keep_alive_expire() { - self.ka_expire = expire; - } - } - break; - } - Poll::Ready(Some(Err(e))) => { - trace!("Error during response body poll: {:?}", e); - return Err(DispatchError::Unknown); - } - Poll::Pending => { - // response payload stream is not ready we can only flush - return Ok(PollWrite::PendingResponse); - } - } + if result.is_err() { + State::Stop } else { - // write buffer is full, we need to flush - return Ok(PollWrite::PendingResponse); - } - } + self.flags.set(Flags::KEEPALIVE, self.codec.keepalive()); - // we have enought space in write bffer - if self.write_buf.len() < BUFFER_SIZE { - Ok(PollWrite::AllowNext) + match body.size() { + BodySize::None | BodySize::Empty => { + if self.error.is_some() { + State::Stop + } else if self.payload.is_some() { + State::ReadPayload + } else { + self.reset_keepalive(); + State::ReadRequest + } + } + _ => State::SendPayload { body }, + } + } } else { - Ok(PollWrite::Pending) + State::Stop } } - /// Read data from io stream - fn poll_read(&mut self, cx: &mut Context<'_>) -> bool { - let mut completed = false; - - // read socket data into a buf - if !self - .flags - .intersects(Flags::DISCONNECT | Flags::STOP_READING) - { - // drain until request payload is consumed and requires more data (backpressure off) - if !self - .req_payload - .as_ref() - .map(|info| info.need_read(cx) == PayloadStatus::Read) - .unwrap_or(true) - { - return false; - } - - // read data from socket - let io = self.io.as_mut().unwrap(); - let buf = &mut self.read_buf; - - // increase read buffer size - let remaining = buf.capacity() - buf.len(); - if remaining < READ_LW_BUFFER_SIZE { - buf.reserve(BUFFER_SIZE); - } - - while buf.len() < BUFFER_SIZE { - match Pin::new(&mut *io).poll_read_buf(cx, buf) { - Poll::Pending => { - completed = true; - break; - } - Poll::Ready(Ok(n)) => { - if n == 0 { - trace!( - "Disconnected during read, buffer size {}", - buf.len() - ); - self.flags.insert(Flags::DISCONNECT); - break; - } - self.flags.remove(Flags::READ_EOF); - } - Poll::Ready(Err(e)) => { - trace!("Error during read: {:?}", e); - self.flags.insert(Flags::DISCONNECT); - self.error = Some(DispatchError::Io(e)); - break; - } - } - } - } - - completed - } - - fn internal_error(&mut self, msg: &'static str) -> DispatcherMessage { - error!("{}", msg); - self.flags.insert(Flags::DISCONNECT | Flags::READ_EOF); - self.error = Some(DispatchError::InternalError); - DispatcherMessage::Error(Response::InternalServerError().finish().drop_body()) - } - - fn decode_error(&mut self, e: ParseError) -> DispatcherMessage { - // error during request decoding - if let Some(mut payload) = self.req_payload.take() { - payload.set_error(PayloadError::EncodingCorrupted); - } - - // Malformed requests should be responded with 400 - self.flags.insert(Flags::STOP_READING); - self.read_buf.clear(); - self.error = Some(e.into()); - DispatcherMessage::Error(Response::BadRequest().finish().drop_body()) - } - - fn decode_payload(&mut self) -> bool { - if self.flags.contains(Flags::READ_EOF) - || self.req_payload.is_none() - || self.read_buf.is_empty() - { - return false; - } - - let mut updated = false; - loop { - match self.codec.decode(&mut self.read_buf) { - Ok(Some(msg)) => match msg { - Message::Chunk(chunk) => { - updated = true; - if let Some(ref mut payload) = self.req_payload { - if let Some(chunk) = chunk { - payload.feed_data(chunk); - } else { - payload.feed_eof(); - self.req_payload = None; - } - } else { - self.internal_error( - "Internal server error: unexpected payload chunk", - ); - break; - } - } - Message::Item(_) => { - self.internal_error( - "Internal server error: unexpected http message", - ); - break; - } - }, - Ok(None) => { - self.flags.insert(Flags::READ_EOF); - break; - } - Err(e) => { - self.decode_error(e); - break; - } - } - } - - updated - } - - fn decode_message(&mut self) -> Option { - if self.flags.contains(Flags::READ_EOF) || self.read_buf.is_empty() { - return None; - } - - match self.codec.decode(&mut self.read_buf) { - Ok(Some(msg)) => { - self.flags.insert(Flags::STARTED); - - match msg { - Message::Item(mut req) => { - let pl = self.codec.message_type(); - req.head_mut().peer_addr = self.peer_addr; - - // set on_connect data - if let Some(ref on_connect) = self.on_connect { - on_connect.set(&mut req.extensions_mut()); - } - - // handle upgrade request - if pl == MessageType::Stream && self.config.upgrade.is_some() { - self.flags.insert(Flags::STOP_READING); - Some(DispatcherMessage::Upgrade(req)) - } else { - // handle request with payload - if pl == MessageType::Payload || pl == MessageType::Stream { - let (ps, pl) = Payload::create(false); - let (req1, _) = - req.replace_payload(crate::http::Payload::H1(pl)); - req = req1; - self.req_payload = Some(ps); - } - - Some(DispatcherMessage::Request(req)) - } - } - Message::Chunk(_) => Some(self.internal_error( - "Internal server error: unexpected payload chunk", - )), - } - } - Ok(None) => { - self.flags.insert(Flags::READ_EOF); - None - } - Err(e) => Some(self.decode_error(e)), - } - } - - /// keep-alive timer - fn poll_keepalive(&mut self, cx: &mut Context<'_>, idle: bool) -> bool { - let ka_timer = self.ka_timer.as_mut().unwrap(); - // do nothing for disconnected or upgrade socket or if keep-alive timer is disabled - if self.flags.contains(Flags::DISCONNECT) { - return false; - } - // slow request timeout - else if !self.flags.contains(Flags::STARTED) { - if Pin::new(ka_timer).poll(cx).is_ready() { - // timeout on first request (slow request) return 408 - trace!("Slow request timeout"); - let _ = self.send_response( - Response::RequestTimeout().finish().drop_body(), - ResponseBody::Other(Body::Empty), - ); - self.flags.insert(Flags::STARTED | Flags::SHUTDOWN); - return true; - } - } - // normal keep-alive, but only if we are not processing any requests - else if idle { - // keep-alive timer - if Pin::new(&mut *ka_timer).poll(cx).is_ready() { - if ka_timer.deadline() >= self.ka_expire { - // check for any outstanding tasks - if self.write_buf.is_empty() { - trace!("Keep-alive timeout, close connection"); - self.flags.insert(Flags::SHUTDOWN); - return true; - } else if let Some(dl) = self.config.keep_alive_expire() { - // extend keep-alive timer - ka_timer.reset(dl); - } + fn send_payload( + &mut self, + item: Option>>, + ) -> Option> { + match item { + Some(Ok(item)) => { + trace!("Got response chunk: {:?}", item.len()); + if let Err(err) = self + .state + .write_item(Message::Chunk(Some(item)), &self.codec) + { + self.error = Some(DispatchError::Encode(err)); + Some(State::Stop) } else { - ka_timer.reset(self.ka_expire); + None } - let _ = Pin::new(ka_timer).poll(cx); + } + None => { + trace!("Response payload eof"); + if let Err(err) = + self.state.write_item(Message::Chunk(None), &self.codec) + { + self.error = Some(DispatchError::Encode(err)); + Some(State::Stop) + } else if self.payload.is_some() { + Some(State::ReadPayload) + } else { + self.reset_keepalive(); + Some(State::ReadRequest) + } + } + Some(Err(e)) => { + trace!("Error during response body poll: {:?}", e); + self.error = Some(DispatchError::ResponsePayload(e)); + Some(State::Stop) } } - false } - fn process_response( - &mut self, - res: Response, - ) -> Result, DispatchError> { - let (res, body) = res.replace_body(()); - if self.send_response(res, body)? { - // response does not have body, so we can process next request - self.process_messages(CallProcess::Next(CallState::Io)) + /// Process request's payload + fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> PollPayloadStatus { + // check if payload data is required + if let Some(ref mut payload) = self.payload { + match payload.1.poll_data_required(cx) { + PayloadStatus::Read => { + // read request payload + let mut updated = false; + loop { + let item = self.state.with_read_buf(|buf| payload.0.decode(buf)); + match item { + Ok(Some(PayloadItem::Chunk(chunk))) => { + updated = true; + payload.1.feed_data(chunk); + } + Ok(Some(PayloadItem::Eof)) => { + payload.1.feed_eof(); + self.payload = None; + if !updated { + return PollPayloadStatus::Done; + } + break; + } + Ok(None) => { + self.state.dsp_read_more_data(cx.waker()); + break; + } + Err(e) => { + payload.1.set_error(PayloadError::EncodingCorrupted); + self.payload = None; + self.error = Some(DispatchError::Parse(e)); + return PollPayloadStatus::Dropped; + } + } + } + if updated { + PollPayloadStatus::Updated + } else { + PollPayloadStatus::Pending + } + } + PayloadStatus::Pause => PollPayloadStatus::Pending, + PayloadStatus::Dropped => { + // service call is not interested in payload + // wait until future completes and then close + // connection + self.payload = None; + self.error = Some(DispatchError::PayloadIsNotConsumed); + PollPayloadStatus::Dropped + } + } } else { - Ok(CallProcess::Next(CallState::Io)) + PollPayloadStatus::Done } } - - fn process_messages( - &mut self, - io: CallProcess, - ) -> Result, DispatchError> { - while let Some(msg) = self.decode_message() { - return match msg { - DispatcherMessage::Request(req) => { - if self.req_payload.is_some() { - self.decode_payload(); - } - - // Handle `EXPECT: 100-Continue` header - Ok(CallProcess::Next(if req.head().expect() { - CallState::Expect { - fut: self.config.expect.call(req), - } - } else { - CallState::Service { - fut: self.config.service.call(req), - } - })) - } - // switch to upgrade handler - DispatcherMessage::Upgrade(req) => { - self.flags.insert(Flags::UPGRADE); - let mut parts = FramedParts::with_read_buf( - self.io.take().unwrap(), - mem::take(&mut self.codec), - mem::take(&mut self.read_buf), - ); - parts.write_buf = mem::take(&mut self.write_buf); - let framed = Framed::from_parts(parts); - - Ok(CallProcess::Upgrade( - self.config.upgrade.as_ref().unwrap().call((req, framed)), - )) - } - DispatcherMessage::Error(res) => { - if self.send_response(res, ResponseBody::Other(Body::Empty))? { - // response does not have body, so we can process next request - continue; - } else { - return Ok(io); - } - } - }; - } - Ok(io) - } -} - -#[cfg(test)] -mod tests { - use bytes::Bytes; - use futures::future::{lazy, ok, Future, FutureExt}; - use futures::StreamExt; - use rand::Rng; - use std::rc::Rc; - use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; - use std::sync::Arc; - use std::time::Duration; - - use super::*; - use crate::http::config::{DispatcherConfig, ServiceConfig}; - use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler}; - use crate::http::{body, Request, ResponseHead, StatusCode}; - use crate::rt::time::delay_for; - use crate::service::IntoService; - use crate::testing::Io; - - /// Create http/1 dispatcher. - pub(crate) fn h1( - stream: Io, - service: F, - ) -> Dispatcher> - where - F: IntoService, - S: Service, - S::Error: ResponseError, - S::Response: Into>, - B: MessageBody, - { - Dispatcher::new( - Rc::new(DispatcherConfig::new( - ServiceConfig::default(), - service.into_service(), - ExpectHandler, - None, - )), - stream, - None, - None, - ) - } - - pub(crate) fn spawn_h1(stream: Io, service: F) - where - F: IntoService, - S: Service + 'static, - S::Error: ResponseError, - S::Response: Into>, - B: MessageBody + 'static, - { - crate::rt::spawn( - Dispatcher::>::new( - Rc::new(DispatcherConfig::new( - ServiceConfig::default(), - service.into_service(), - ExpectHandler, - None, - )), - stream, - None, - None, - ), - ); - } - - fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead { - decoder.decode(buf).unwrap().unwrap() - } - - #[ntex_rt::test] - async fn test_req_parse_err() { - let (client, server) = Io::create(); - client.remote_buffer_cap(1024); - client.write("GET /test HTTP/1\r\n\r\n"); - - let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - assert!(h1.inner.flags.contains(Flags::SHUTDOWN)); - client - .local_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n")); - - client.close().await; - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); - assert!(h1.inner.flags.contains(Flags::SHUTDOWN_IO)); - } - - #[ntex_rt::test] - async fn test_pipeline() { - let (client, server) = Io::create(); - client.remote_buffer_cap(4096); - let mut decoder = ClientCodec::default(); - spawn_h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); - - client.write("GET /test HTTP/1.1\r\n\r\n"); - - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(!client.is_server_dropped()); - - client.write("GET /test HTTP/1.1\r\n\r\n"); - client.write("GET /test HTTP/1.1\r\n\r\n"); - - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(decoder.decode(&mut buf).unwrap().is_none()); - assert!(!client.is_server_dropped()); - - client.close().await; - assert!(client.is_server_dropped()); - } - - #[ntex_rt::test] - async fn test_pipeline_with_payload() { - let (client, server) = Io::create(); - client.remote_buffer_cap(4096); - let mut decoder = ClientCodec::default(); - spawn_h1(server, |mut req: Request| async move { - let mut p = req.take_payload(); - while let Some(_) = p.next().await {} - Ok::<_, io::Error>(Response::Ok().finish()) - }); - - client.write("GET /test HTTP/1.1\r\ncontent-length: 5\r\n\r\n"); - delay_for(Duration::from_millis(50)).await; - client.write("xxxxx"); - - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(!client.is_server_dropped()); - - client.write("GET /test HTTP/1.1\r\n\r\n"); - - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(decoder.decode(&mut buf).unwrap().is_none()); - assert!(!client.is_server_dropped()); - - client.close().await; - assert!(client.is_server_dropped()); - } - - #[ntex_rt::test] - async fn test_pipeline_with_delay() { - let (client, server) = Io::create(); - client.remote_buffer_cap(4096); - let mut decoder = ClientCodec::default(); - spawn_h1(server, |_| async { - delay_for(Duration::from_millis(100)).await; - Ok::<_, io::Error>(Response::Ok().finish()) - }); - - client.write("GET /test HTTP/1.1\r\n\r\n"); - - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(!client.is_server_dropped()); - - client.write("GET /test HTTP/1.1\r\n\r\n"); - client.write("GET /test HTTP/1.1\r\n\r\n"); - delay_for(Duration::from_millis(50)).await; - client.write("GET /test HTTP/1.1\r\n\r\n"); - - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(decoder.decode(&mut buf).unwrap().is_none()); - assert!(!client.is_server_dropped()); - - buf.extend(client.read().await.unwrap()); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(decoder.decode(&mut buf).unwrap().is_none()); - assert!(!client.is_server_dropped()); - - client.close().await; - assert!(client.is_server_dropped()); - } - - #[ntex_rt::test] - /// if socket is disconnected - /// h1 dispatcher still processes all incoming requests - /// but it does not write any data to socket - async fn test_write_disconnected() { - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); - - let (client, server) = Io::create(); - spawn_h1(server, move |_| { - num2.fetch_add(1, Ordering::Relaxed); - ok::<_, io::Error>(Response::Ok().finish()) - }); - - client.remote_buffer_cap(1024); - client.write("GET /test HTTP/1.1\r\n\r\n"); - client.write("GET /test HTTP/1.1\r\n\r\n"); - client.write("GET /test HTTP/1.1\r\n\r\n"); - client.close().await; - assert!(client.is_server_dropped()); - assert!(client.read_any().is_empty()); - - // all request must be handled - assert_eq!(num.load(Ordering::Relaxed), 3); - } - - #[ntex_rt::test] - async fn test_read_large_message() { - let (client, server) = Io::create(); - client.remote_buffer_cap(4096); - - let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); - let mut decoder = ClientCodec::default(); - - let data = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(70_000) - .map(char::from) - .collect::(); - client.write("GET /test HTTP/1.1\r\nContent-Length: "); - client.write(data); - - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - assert!(h1.inner.flags.contains(Flags::SHUTDOWN)); - - let mut buf = client.read().await.unwrap(); - assert_eq!(load(&mut decoder, &mut buf).status, StatusCode::BAD_REQUEST); - } - - #[ntex_rt::test] - async fn test_read_backpressure() { - let mark = Arc::new(AtomicBool::new(false)); - let mark2 = mark.clone(); - - let (client, server) = Io::create(); - client.remote_buffer_cap(4096); - spawn_h1(server, move |mut req: Request| { - let m = mark2.clone(); - async move { - // read one chunk - let mut pl = req.take_payload(); - let _ = pl.next().await.unwrap().unwrap(); - m.store(true, Ordering::Relaxed); - // sleep - delay_for(Duration::from_secs(999_999)).await; - Ok::<_, io::Error>(Response::Ok().finish()) - } - }); - - client.write("GET /test HTTP/1.1\r\nContent-Length: 1048576\r\n\r\n"); - delay_for(Duration::from_millis(50)).await; - - // buf must be consumed - assert_eq!(client.remote_buffer(|buf| buf.len()), 0); - - // io should be drained only by no more than MAX_BUFFER_SIZE - let random_bytes: Vec = - (0..1_048_576).map(|_| rand::random::()).collect(); - client.write(random_bytes); - - delay_for(Duration::from_millis(50)).await; - assert!(client.remote_buffer(|buf| buf.len()) > 1_048_576 - BUFFER_SIZE * 3); - assert!(mark.load(Ordering::Relaxed)); - } - - #[ntex_rt::test] - async fn test_write_backpressure() { - let num = Arc::new(AtomicUsize::new(0)); - let num2 = num.clone(); - - struct Stream(Arc); - - impl body::MessageBody for Stream { - fn size(&self) -> body::BodySize { - body::BodySize::Stream - } - fn poll_next_chunk( - &mut self, - _: &mut Context<'_>, - ) -> Poll>>> { - let data = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(65_536) - .map(char::from) - .collect::(); - self.0.fetch_add(data.len(), Ordering::Relaxed); - - Poll::Ready(Some(Ok(Bytes::from(data)))) - } - } - - let (client, server) = Io::create(); - let mut h1 = h1(server, move |_| { - let n = num2.clone(); - async move { Ok::<_, io::Error>(Response::Ok().message_body(Stream(n.clone()))) } - .boxed_local() - }); - - // do not allow to write to socket - client.remote_buffer_cap(0); - client.write("GET /test HTTP/1.1\r\n\r\n"); - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - - // buf must be consumed - assert_eq!(client.remote_buffer(|buf| buf.len()), 0); - - // amount of generated data - assert_eq!(num.load(Ordering::Relaxed), 65_536); - - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - assert_eq!(num.load(Ordering::Relaxed), 65_536); - // response message + chunking encoding - assert_eq!(h1.inner.write_buf.len(), 65629); - - client.remote_buffer_cap(65536); - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - assert_eq!(num.load(Ordering::Relaxed), 65_536 * 2); - } - - #[ntex_rt::test] - async fn test_disconnect_during_response_body_pending() { - struct Stream(bool); - - impl body::MessageBody for Stream { - fn size(&self) -> body::BodySize { - body::BodySize::Sized(2048) - } - fn poll_next_chunk( - &mut self, - _: &mut Context<'_>, - ) -> Poll>>> { - if self.0 { - Poll::Pending - } else { - self.0 = true; - let data = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(1024) - .map(char::from) - .collect::(); - Poll::Ready(Some(Ok(Bytes::from(data)))) - } - } - } - - let (client, server) = Io::create(); - client.remote_buffer_cap(4096); - let mut h1 = h1(server, |_| { - ok::<_, io::Error>(Response::Ok().message_body(Stream(false))) - }); - - client.write("GET /test HTTP/1.1\r\n\r\n"); - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - - // buf must be consumed - assert_eq!(client.remote_buffer(|buf| buf.len()), 0); - - let mut decoder = ClientCodec::default(); - let mut buf = client.read().await.unwrap(); - assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); - - client.close().await; - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); - } } diff --git a/ntex/src/http/h1/encoder.rs b/ntex/src/http/h1/encoder.rs index ea58ac7f..62cf6830 100644 --- a/ntex/src/http/h1/encoder.rs +++ b/ntex/src/http/h1/encoder.rs @@ -1,7 +1,6 @@ use std::io::Write; use std::marker::PhantomData; -use std::ptr::copy_nonoverlapping; -use std::{cmp, io, mem, ptr, slice}; +use std::{cell::Cell, cmp, io, mem, ptr, ptr::copy_nonoverlapping, slice}; use bytes::{BufMut, BytesMut}; @@ -18,7 +17,7 @@ const AVERAGE_HEADER_SIZE: usize = 30; #[derive(Debug)] pub(super) struct MessageEncoder { pub(super) length: BodySize, - pub(super) te: TransferEncoding, + pub(super) te: Cell, _t: PhantomData, } @@ -26,7 +25,17 @@ impl Default for MessageEncoder { fn default() -> Self { MessageEncoder { length: BodySize::None, - te: TransferEncoding::empty(), + te: Cell::new(TransferEncoding::empty()), + _t: PhantomData, + } + } +} + +impl Clone for MessageEncoder { + fn clone(&self) -> Self { + MessageEncoder { + length: self.length, + te: self.te.clone(), _t: PhantomData, } } @@ -41,10 +50,10 @@ pub(super) trait MessageType: Sized { fn chunked(&self) -> bool; - fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>; + fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()>; fn encode_headers( - &mut self, + &self, dst: &mut BytesMut, version: Version, mut length: BodySize, @@ -208,7 +217,7 @@ impl MessageType for Response<()> { None } - fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> { let head = self.head(); let reason = head.reason().as_bytes(); dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len()); @@ -237,7 +246,7 @@ impl MessageType for RequestHeadType { self.extra_headers() } - fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> { + fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> { let head = self.as_ref(); dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE); write!( @@ -264,20 +273,26 @@ impl MessageType for RequestHeadType { impl MessageEncoder { /// Encode message pub(super) fn encode_chunk( - &mut self, + &self, msg: &[u8], buf: &mut BytesMut, ) -> io::Result { - self.te.encode(msg, buf) + let mut te = self.te.get(); + let result = te.encode(msg, buf); + self.te.set(te); + result } /// Encode eof - pub(super) fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { - self.te.encode_eof(buf) + pub(super) fn encode_eof(&self, buf: &mut BytesMut) -> io::Result<()> { + let mut te = self.te.get(); + let result = te.encode_eof(buf); + self.te.set(te); + result } pub(super) fn encode( - &mut self, + &self, dst: &mut BytesMut, message: &mut T, head: bool, @@ -289,7 +304,7 @@ impl MessageEncoder { ) -> io::Result<()> { // transfer encoding if !head { - self.te = match length { + self.te.set(match length { BodySize::Empty => TransferEncoding::empty(), BodySize::Sized(len) => TransferEncoding::length(len), BodySize::Stream => { @@ -300,9 +315,9 @@ impl MessageEncoder { } } BodySize::None => TransferEncoding::empty(), - }; + }); } else { - self.te = TransferEncoding::empty(); + self.te.set(TransferEncoding::empty()); } message.encode_status(dst)?; @@ -311,12 +326,12 @@ impl MessageEncoder { } /// Encoders to handle different Transfer-Encodings. -#[derive(Debug)] +#[derive(Debug, Copy, Clone)] pub(super) struct TransferEncoding { kind: TransferEncodingKind, } -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Copy)] enum TransferEncodingKind { /// An Encoder for when Transfer-Encoding includes `chunked`. Chunked(bool), @@ -368,14 +383,15 @@ impl TransferEncoding { buf.extend_from_slice(msg); Ok(eof) } - TransferEncodingKind::Chunked(ref mut eof) => { - if *eof { + TransferEncodingKind::Chunked(eof) => { + if eof { return Ok(true); } - if msg.is_empty() { - *eof = true; + let result = if msg.is_empty() { buf.extend_from_slice(b"0\r\n\r\n"); + self.kind = TransferEncodingKind::Chunked(true); + true } else { writeln!(helpers::Writer(buf), "{:X}\r", msg.len()) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; @@ -383,20 +399,22 @@ impl TransferEncoding { buf.reserve(msg.len() + 2); buf.extend_from_slice(msg); buf.extend_from_slice(b"\r\n"); - } - Ok(*eof) + false + }; + Ok(result) } - TransferEncodingKind::Length(ref mut remaining) => { - if *remaining > 0 { + TransferEncodingKind::Length(mut remaining) => { + if remaining > 0 { if msg.is_empty() { - return Ok(*remaining == 0); + return Ok(remaining == 0); } - let len = cmp::min(*remaining, msg.len() as u64); + let len = cmp::min(remaining, msg.len() as u64); buf.extend_from_slice(&msg[..len as usize]); - *remaining -= len as u64; - Ok(*remaining == 0) + remaining -= len as u64; + self.kind = TransferEncodingKind::Length(remaining); + Ok(remaining == 0) } else { Ok(true) } @@ -416,10 +434,10 @@ impl TransferEncoding { Ok(()) } } - TransferEncodingKind::Chunked(ref mut eof) => { - if !*eof { - *eof = true; + TransferEncodingKind::Chunked(eof) => { + if !eof { buf.extend_from_slice(b"0\r\n\r\n"); + self.kind = TransferEncodingKind::Chunked(true); } Ok(()) } @@ -614,7 +632,7 @@ mod tests { ); extra_headers.insert(DATE, HeaderValue::from_static("date")); - let mut head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers)); + let head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers)); let _ = head.encode_headers( &mut bytes, diff --git a/ntex/src/http/h1/mod.rs b/ntex/src/http/h1/mod.rs index c7b6c880..8b8a63b5 100644 --- a/ntex/src/http/h1/mod.rs +++ b/ntex/src/http/h1/mod.rs @@ -13,6 +13,7 @@ mod upgrade; pub use self::client::{ClientCodec, ClientPayloadCodec}; pub use self::codec::Codec; +pub use self::decoder::{PayloadDecoder, PayloadItem, PayloadType}; pub use self::expect::ExpectHandler; pub use self::payload::Payload; pub use self::service::{H1Service, H1ServiceHandler}; @@ -54,33 +55,3 @@ pub(crate) fn reserve_readbuf(src: &mut BytesMut) { src.reserve(HW - cap); } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::http::request::Request; - - impl Message { - pub fn message(self) -> Request { - match self { - Message::Item(req) => req, - _ => panic!("error"), - } - } - - pub fn chunk(self) -> Bytes { - match self { - Message::Chunk(Some(data)) => data, - _ => panic!("error"), - } - } - - pub fn eof(self) -> bool { - match self { - Message::Chunk(None) => true, - Message::Chunk(Some(_)) => false, - _ => panic!("error"), - } - } - } -} diff --git a/ntex/src/http/h1/payload.rs b/ntex/src/http/h1/payload.rs index 42457880..5962a7e3 100644 --- a/ntex/src/http/h1/payload.rs +++ b/ntex/src/http/h1/payload.rs @@ -119,8 +119,8 @@ impl PayloadSender { } } - pub(super) fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus { - // we check need_read only if Payload (other side) is alive, + pub(super) fn poll_data_required(&self, cx: &mut Context<'_>) -> PayloadStatus { + // we check only if Payload (other side) is alive, // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { if shared.borrow().need_read { diff --git a/ntex/src/http/h1/service.rs b/ntex/src/http/h1/service.rs index 74efb0d6..bbcd701e 100644 --- a/ntex/src/http/h1/service.rs +++ b/ntex/src/http/h1/service.rs @@ -1,11 +1,11 @@ -use std::marker::PhantomData; -use std::rc::Rc; -use std::task::{Context, Poll}; -use std::{fmt, net}; +use std::{ + error::Error, fmt, marker::PhantomData, net, rc::Rc, task::Context, task::Poll, +}; use futures::future::{ok, FutureExt, LocalBoxFuture}; -use crate::codec::{AsyncRead, AsyncWrite, Framed}; +use crate::codec::{AsyncRead, AsyncWrite}; +use crate::framed::State as IoState; use crate::http::body::MessageBody; use crate::http::config::{DispatcherConfig, ServiceConfig}; use crate::http::error::{DispatchError, ResponseError}; @@ -34,7 +34,7 @@ pub struct H1Service> { impl H1Service where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, @@ -59,21 +59,17 @@ where impl H1Service where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into>, S::Future: 'static, B: MessageBody, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, Framed), - Response = (), - >, - U::Error: fmt::Display + ResponseError, + U: ServiceFactory, + U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, { @@ -105,21 +101,21 @@ mod openssl { impl H1Service, S, B, X, U> where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into>, S::Future: 'static, B: MessageBody, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, Framed, Codec>), + Request = (Request, IoState, Codec), Response = (), >, - U::Error: fmt::Display + ResponseError, + U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, { @@ -136,7 +132,7 @@ mod openssl { > { pipeline_factory( Acceptor::new(acceptor) - .timeout(self.handshake_timeout) + .timeout((self.handshake_timeout as u64) * 1000) .map_err(SslError::Ssl) .map_init_err(|_| panic!()), ) @@ -159,21 +155,21 @@ mod rustls { impl H1Service, S, B, X, U> where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into>, S::Future: 'static, B: MessageBody, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, Framed, Codec>), + Request = (Request, IoState, Codec), Response = (), >, - U::Error: fmt::Display + ResponseError, + U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, { @@ -190,7 +186,7 @@ mod rustls { > { pipeline_factory( Acceptor::new(config) - .timeout(self.handshake_timeout) + .timeout((self.handshake_timeout as u64) * 1000) .map_err(SslError::Ssl) .map_init_err(|_| panic!()), ) @@ -206,7 +202,7 @@ mod rustls { impl H1Service where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into>, S::InitError: fmt::Debug, S::Future: 'static, @@ -215,7 +211,7 @@ where pub fn expect(self, expect: X1) -> H1Service where X1: ServiceFactory, - X1::Error: ResponseError, + X1::Error: ResponseError + 'static, X1::InitError: fmt::Debug, X1::Future: 'static, { @@ -232,8 +228,8 @@ where pub fn upgrade(self, upgrade: Option) -> H1Service where - U1: ServiceFactory), Response = ()>, - U1::Error: fmt::Display, + U1: ServiceFactory, + U1::Error: fmt::Display + Error + 'static, U1::InitError: fmt::Debug, U1::Future: 'static, { @@ -262,17 +258,17 @@ impl ServiceFactory for H1Service where T: AsyncRead + AsyncWrite + Unpin + 'static, S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into>, S::InitError: fmt::Debug, S::Future: 'static, B: MessageBody, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, - U: ServiceFactory), Response = ()>, - U::Error: fmt::Display + ResponseError, + U: ServiceFactory, + U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, { @@ -328,20 +324,20 @@ pub struct H1ServiceHandler { impl Service for H1ServiceHandler where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + 'static, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, - X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display + ResponseError, + X::Error: ResponseError + 'static, + U: Service, + U::Error: fmt::Display + Error + 'static, { type Request = (T, Option); type Response = (); type Error = DispatchError; - type Future = Dispatcher; + type Future = Dispatcher; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let cfg = self.config.as_ref(); @@ -369,7 +365,7 @@ where upg.poll_ready(cx) .map_err(|e| { log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(Box::new(e)) + DispatchError::Upgrade(Box::new(e)) })? .is_ready() && ready @@ -407,6 +403,6 @@ where None }; - Dispatcher::new(self.config.clone(), io, addr, on_connect) + Dispatcher::new(io, self.config.clone(), addr, on_connect) } } diff --git a/ntex/src/http/h1/upgrade.rs b/ntex/src/http/h1/upgrade.rs index b259c16a..556b2546 100644 --- a/ntex/src/http/h1/upgrade.rs +++ b/ntex/src/http/h1/upgrade.rs @@ -1,10 +1,8 @@ -use std::io; -use std::marker::PhantomData; -use std::task::{Context, Poll}; +use std::{io, marker::PhantomData, task::Context, task::Poll}; use futures::future::Ready; -use crate::codec::Framed; +use crate::framed::State; use crate::http::h1::Codec; use crate::http::request::Request; use crate::{Service, ServiceFactory}; @@ -13,7 +11,7 @@ pub struct UpgradeHandler(PhantomData); impl ServiceFactory for UpgradeHandler { type Config = (); - type Request = (Request, Framed); + type Request = (Request, State, Codec); type Response = (); type Error = io::Error; type Service = UpgradeHandler; @@ -27,7 +25,7 @@ impl ServiceFactory for UpgradeHandler { } impl Service for UpgradeHandler { - type Request = (Request, Framed); + type Request = (Request, State, Codec); type Response = (); type Error = io::Error; type Future = Ready>; diff --git a/ntex/src/http/h2/dispatcher.rs b/ntex/src/http/h2/dispatcher.rs index 6959c907..d50f8bb7 100644 --- a/ntex/src/http/h2/dispatcher.rs +++ b/ntex/src/http/h2/dispatcher.rs @@ -40,7 +40,7 @@ impl Dispatcher where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, { @@ -76,7 +76,7 @@ impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -100,9 +100,7 @@ where } let (parts, body) = req.into_parts(); - let mut req = Request::with_payload(Payload::< - crate::http::payload::PayloadStream, - >::H2( + let mut req = Request::with_payload(Payload::H2( crate::http::h2::Payload::new(body), )); @@ -155,7 +153,7 @@ pin_project_lite::pin_project! { impl ServiceResponse where F: Future>, - E: ResponseError, + E: ResponseError + 'static, I: Into>, B: MessageBody, { @@ -221,7 +219,7 @@ where impl Future for ServiceResponse where F: Future>, - E: ResponseError, + E: ResponseError + 'static, I: Into>, B: MessageBody, { @@ -260,7 +258,7 @@ where } Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => { - let res: Response = e.into(); + let res: Response = (&e).into(); let (res, body) = res.replace_body(()); let mut send = send.take().unwrap(); diff --git a/ntex/src/http/h2/service.rs b/ntex/src/http/h2/service.rs index b74364fb..a6e4f94f 100644 --- a/ntex/src/http/h2/service.rs +++ b/ntex/src/http/h2/service.rs @@ -38,7 +38,7 @@ pub struct H2Service { impl H2Service where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into> + 'static, S::Future: 'static, ::Future: 'static, @@ -52,7 +52,7 @@ where H2Service { on_connect: None, srv: service.into_factory(), - handshake_timeout: cfg.0.ssl_handshake_timeout, + handshake_timeout: (cfg.0.ssl_handshake_timeout as u64) * 1000, _t: PhantomData, cfg, } @@ -71,7 +71,7 @@ where impl H2Service where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into> + 'static, S::Future: 'static, ::Future: 'static, @@ -108,7 +108,7 @@ mod openssl { impl H2Service, S, B> where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into> + 'static, S::Future: 'static, ::Future: 'static, @@ -151,7 +151,7 @@ mod rustls { impl H2Service, S, B> where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into> + 'static, S::Future: 'static, ::Future: 'static, @@ -192,7 +192,7 @@ impl ServiceFactory for H2Service where T: AsyncRead + AsyncWrite + Unpin + 'static, S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Response: Into> + 'static, S::Future: 'static, ::Future: 'static, @@ -236,7 +236,7 @@ impl Service for H2ServiceHandler where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -295,7 +295,7 @@ pub struct H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, @@ -307,7 +307,7 @@ impl Future for H2ServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody, diff --git a/ntex/src/http/mod.rs b/ntex/src/http/mod.rs index 7f5a2e06..85c065ed 100644 --- a/ntex/src/http/mod.rs +++ b/ntex/src/http/mod.rs @@ -45,7 +45,3 @@ pub enum Protocol { Http1, Http2, } - -#[doc(hidden)] -#[deprecated(since = "0.1.19", note = "Use ntex::util::Extensions instead")] -pub use crate::util::Extensions; diff --git a/ntex/src/http/payload.rs b/ntex/src/http/payload.rs index c079d04e..0441aac4 100644 --- a/ntex/src/http/payload.rs +++ b/ntex/src/http/payload.rs @@ -1,45 +1,44 @@ -use std::fmt; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{fmt, mem, pin::Pin, task::Context, task::Poll}; use bytes::Bytes; use futures::Stream; use h2::RecvStream; use super::error::PayloadError; +use super::{h1, h2 as h2d}; /// Type represent boxed payload pub type PayloadStream = Pin>>>; /// Type represent streaming payload -pub enum Payload { +pub enum Payload { None, - H1(crate::http::h1::Payload), - H2(crate::http::h2::Payload), - Stream(S), + H1(h1::Payload), + H2(h2d::Payload), + Stream(PayloadStream), } -impl Default for Payload { +impl Default for Payload { fn default() -> Self { Payload::None } } -impl From for Payload { - fn from(v: crate::http::h1::Payload) -> Self { +impl From for Payload { + fn from(v: h1::Payload) -> Self { Payload::H1(v) } } -impl From for Payload { - fn from(v: crate::http::h2::Payload) -> Self { +impl From for Payload { + fn from(v: h2d::Payload) -> Self { Payload::H2(v) } } -impl From for Payload { +impl From for Payload { fn from(v: RecvStream) -> Self { - Payload::H2(crate::http::h2::Payload::new(v)) + Payload::H2(h2d::Payload::new(v)) } } @@ -49,7 +48,7 @@ impl From for Payload { } } -impl fmt::Debug for Payload { +impl fmt::Debug for Payload { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Payload::None => write!(f, "Payload::None"), @@ -60,17 +59,22 @@ impl fmt::Debug for Payload { } } -impl Payload { +impl Payload { /// Takes current payload and replaces it with `None` value - pub fn take(&mut self) -> Payload { - std::mem::take(self) + pub fn take(&mut self) -> Self { + mem::take(self) + } + + /// Create payload from stream + pub fn from_stream(stream: S) -> Self + where + S: Stream> + 'static, + { + Payload::Stream(Box::pin(stream)) } } -impl Stream for Payload -where - S: Stream> + Unpin, -{ +impl Stream for Payload { type Item = Result; #[inline] @@ -93,19 +97,12 @@ mod tests { #[test] fn payload_debug() { - assert!( - format!("{:?}", Payload::::None).contains("Payload::None") - ); + assert!(format!("{:?}", Payload::None).contains("Payload::None")); + assert!(format!("{:?}", Payload::H1(h1::Payload::create(false).1)) + .contains("Payload::H1")); assert!(format!( "{:?}", - Payload::::H1(crate::http::h1::Payload::create(false).1) - ) - .contains("Payload::H1")); - assert!(format!( - "{:?}", - Payload::::Stream(Box::pin( - crate::http::h1::Payload::create(false).1 - )) + Payload::Stream(Box::pin(h1::Payload::create(false).1)) ) .contains("Payload::Stream")); } diff --git a/ntex/src/http/request.rs b/ntex/src/http/request.rs index 99412eab..1cbf35ed 100644 --- a/ntex/src/http/request.rs +++ b/ntex/src/http/request.rs @@ -1,21 +1,20 @@ -use std::cell::{Ref, RefMut}; -use std::{fmt, net}; +use std::{cell::Ref, cell::RefMut, fmt, mem, net}; use http::{header, Method, Uri, Version}; use crate::http::header::HeaderMap; use crate::http::httpmessage::HttpMessage; use crate::http::message::{Message, RequestHead}; -use crate::http::payload::{Payload, PayloadStream}; +use crate::http::payload::Payload; use crate::util::Extensions; /// Request -pub struct Request

{ - pub(crate) payload: Payload

, +pub struct Request { + pub(crate) payload: Payload, pub(crate) head: Message, } -impl

HttpMessage for Request

{ +impl HttpMessage for Request { #[inline] fn message_headers(&self) -> &HeaderMap { &self.head().headers @@ -34,7 +33,7 @@ impl

HttpMessage for Request

{ } } -impl From> for Request { +impl From> for Request { fn from(head: Message) -> Self { Request { head, @@ -43,9 +42,9 @@ impl From> for Request { } } -impl Request { +impl Request { /// Create new Request instance - pub fn new() -> Request { + pub fn new() -> Request { Request { head: Message::new(), payload: Payload::None, @@ -53,9 +52,9 @@ impl Request { } } -impl

Request

{ +impl Request { /// Create new Request instance - pub fn with_payload(payload: Payload

) -> Request

{ + pub fn with_payload(payload: Payload) -> Request { Request { payload, head: Message::new(), @@ -137,25 +136,18 @@ impl

Request

{ } /// Get request's payload - pub fn payload(&mut self) -> &mut Payload

{ + pub fn payload(&mut self) -> &mut Payload { &mut self.payload } /// Get request's payload - pub fn take_payload(&mut self) -> Payload

{ - std::mem::take(&mut self.payload) + pub fn take_payload(&mut self) -> Payload { + mem::take(&mut self.payload) } - /// Create new Request instance - pub fn replace_payload(self, payload: Payload) -> (Request, Payload

) { - let pl = self.payload; - ( - Request { - payload, - head: self.head, - }, - pl, - ) + /// Replace request's payload, returns old one + pub fn replace_payload(&mut self, payload: Payload) -> Payload { + mem::replace(&mut self.payload, payload) } /// Request extensions @@ -172,12 +164,12 @@ impl

Request

{ #[allow(dead_code)] /// Split request into request head and payload - pub(crate) fn into_parts(self) -> (Message, Payload

) { + pub(crate) fn into_parts(self) -> (Message, Payload) { (self.head, self.payload) } } -impl

fmt::Debug for Request

{ +impl fmt::Debug for Request { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!( f, diff --git a/ntex/src/http/response.rs b/ntex/src/http/response.rs index 64312270..62c4c9b0 100644 --- a/ntex/src/http/response.rs +++ b/ntex/src/http/response.rs @@ -610,7 +610,7 @@ impl ResponseBuilder { self.body(Body::from(body)) } - Err(e) => e.into(), + Err(e) => (&e).into(), } } @@ -755,7 +755,7 @@ where fn from(res: Result) -> Self { match res { Ok(val) => val.into(), - Err(err) => err.into(), + Err(err) => (&err).into(), } } } diff --git a/ntex/src/http/service.rs b/ntex/src/http/service.rs index fb3e6f68..2886290c 100644 --- a/ntex/src/http/service.rs +++ b/ntex/src/http/service.rs @@ -1,10 +1,13 @@ -use std::{fmt, marker::PhantomData, net, pin::Pin, rc::Rc, task::Context, task::Poll}; +use std::{ + error, fmt, marker::PhantomData, net, pin::Pin, rc::Rc, task::Context, task::Poll, +}; use bytes::Bytes; use futures::future::{ok, Future, FutureExt, LocalBoxFuture}; use h2::server::{self, Handshake}; -use crate::codec::{AsyncRead, AsyncWrite, Framed}; +use crate::codec::{AsyncRead, AsyncWrite}; +use crate::framed::State; use crate::rt::net::TcpStream; use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory}; @@ -30,7 +33,7 @@ pub struct HttpService impl HttpService where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, S::Future: 'static, @@ -46,7 +49,7 @@ where impl HttpService where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, S::Future: 'static, @@ -86,7 +89,7 @@ where impl HttpService where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, S::Future: 'static, @@ -123,10 +126,10 @@ where where U1: ServiceFactory< Config = (), - Request = (Request, Framed), + Request = (Request, State, h1::Codec), Response = (), >, - U1::Error: fmt::Display, + U1::Error: fmt::Display + error::Error + 'static, U1::InitError: fmt::Debug, U1::Future: 'static, { @@ -153,23 +156,19 @@ where impl HttpService where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, S::Future: 'static, ::Future: 'static, B: MessageBody + 'static, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, Framed), - Response = (), - >, - U::Error: fmt::Display + ResponseError, + U: ServiceFactory, + U::Error: fmt::Display + error::Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, ::Future: 'static, @@ -201,23 +200,23 @@ mod openssl { impl HttpService, S, B, X, U> where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Response: Into> + 'static, S::Future: 'static, ::Future: 'static, B: MessageBody + 'static, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, Framed, h1::Codec>), + Request = (Request, State, h1::Codec), Response = (), >, - U::Error: fmt::Display + ResponseError, + U::Error: fmt::Display + error::Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, ::Future: 'static, @@ -235,7 +234,7 @@ mod openssl { > { pipeline_factory( Acceptor::new(acceptor) - .timeout(self.cfg.0.ssl_handshake_timeout) + .timeout((self.cfg.0.ssl_handshake_timeout as u64) * 1000) .map_err(SslError::Ssl) .map_init_err(|_| panic!()), ) @@ -266,23 +265,23 @@ mod rustls { impl HttpService, S, B, X, U> where S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Future: 'static, S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, Framed, h1::Codec>), + Request = (Request, State, h1::Codec), Response = (), >, - U::Error: fmt::Display + ResponseError, + U::Error: fmt::Display + error::Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, ::Future: 'static, @@ -303,7 +302,7 @@ mod rustls { pipeline_factory( Acceptor::new(config) - .timeout(self.cfg.0.ssl_handshake_timeout) + .timeout((self.cfg.0.ssl_handshake_timeout as u64) * 1000) .map_err(SslError::Ssl) .map_init_err(|_| panic!()), ) @@ -332,23 +331,19 @@ impl ServiceFactory for HttpService where T: AsyncRead + AsyncWrite + Unpin + 'static, S: ServiceFactory, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::InitError: fmt::Debug, S::Future: 'static, S::Response: Into> + 'static, ::Future: 'static, B: MessageBody + 'static, X: ServiceFactory, - X::Error: ResponseError, + X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory< - Config = (), - Request = (Request, Framed), - Response = (), - >, - U::Error: fmt::Display + ResponseError, + U: ServiceFactory, + U::Error: fmt::Display + error::Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, ::Future: 'static, @@ -407,16 +402,16 @@ pub struct HttpServiceHandler { impl Service for HttpServiceHandler where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + 'static, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody + 'static, X: Service, - X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display + ResponseError, + X::Error: ResponseError + 'static, + U: Service, + U::Error: fmt::Display + error::Error + 'static, { type Request = (T, Protocol, Option); type Response = (); @@ -449,7 +444,7 @@ where upg.poll_ready(cx) .map_err(|e| { log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(Box::new(e)) + DispatchError::Upgrade(Box::new(e)) })? .is_ready() && ready @@ -489,7 +484,7 @@ where match proto { Protocol::Http2 => HttpServiceHandlerResponse { - state: State::H2Handshake { + state: ResponseState::H2Handshake { data: Some(( server::handshake(io), self.config.clone(), @@ -499,10 +494,10 @@ where }, }, Protocol::Http1 => HttpServiceHandlerResponse { - state: State::H1 { + state: ResponseState::H1 { fut: h1::Dispatcher::new( - self.config.clone(), io, + self.config.clone(), peer_addr, on_connect, ), @@ -520,36 +515,44 @@ pin_project_lite::pin_project! { T: Unpin, S: Service, S::Error: ResponseError, + S::Error: 'static, S::Response: Into>, S::Response: 'static, B: MessageBody, B: 'static, X: Service, X::Error: ResponseError, - U: Service), Response = ()>, + X::Error: 'static, + U: Service, U::Error: fmt::Display, + U::Error: error::Error, + U::Error: 'static, { #[pin] - state: State, + state: ResponseState, } } pin_project_lite::pin_project! { #[project = StateProject] - enum State + enum ResponseState where S: Service, S::Error: ResponseError, + S::Error: 'static, T: AsyncRead, T: AsyncWrite, T: Unpin, B: MessageBody, X: Service, X::Error: ResponseError, - U: Service), Response = ()>, + X::Error: 'static, + U: Service, U::Error: fmt::Display, + U::Error: error::Error, + U::Error: 'static, { - H1 { #[pin] fut: h1::Dispatcher }, + H1 { #[pin] fut: h1::Dispatcher }, H2 { fut: Dispatcher }, H2Handshake { data: Option<( @@ -566,14 +569,14 @@ impl Future for HttpServiceHandlerResponse where T: AsyncRead + AsyncWrite + Unpin, S: Service, - S::Error: ResponseError, + S::Error: ResponseError + 'static, S::Future: 'static, S::Response: Into> + 'static, B: MessageBody, X: Service, - X::Error: ResponseError, - U: Service), Response = ()>, - U::Error: fmt::Display, + X::Error: ResponseError + 'static, + U: Service, + U::Error: fmt::Display + error::Error + 'static, { type Output = Result<(), DispatchError>; @@ -597,7 +600,7 @@ where panic!() }; let (_, cfg, on_connect, peer_addr) = data.take().unwrap(); - self.as_mut().project().state.set(State::H2 { + self.as_mut().project().state.set(ResponseState::H2 { fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr), }); self.poll(cx) diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index 6bd95274..d24e1257 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -1,11 +1,7 @@ //! Test helpers to use during testing. -use std::convert::TryFrom; -use std::str::FromStr; -use std::sync::mpsc; -use std::{io, net, thread, time}; +use std::{convert::TryFrom, io, net, str::FromStr, sync::mpsc, thread, time}; use bytes::Bytes; -use futures::Stream; #[cfg(feature = "cookie")] use coo_kie::{Cookie, CookieJar}; @@ -316,13 +312,11 @@ impl TestServer { .request(method, self.surl(path.as_ref()).as_str()) } - pub async fn load_body( + /// Load response's body + pub async fn load_body( &mut self, - mut response: ClientResponse, - ) -> Result - where - S: Stream> + Unpin + 'static, - { + mut response: ClientResponse, + ) -> Result { response.body().limit(10_485_760).await } diff --git a/ntex/src/util/stream.rs b/ntex/src/util/stream.rs index 6d55f7c6..2784983f 100644 --- a/ntex/src/util/stream.rs +++ b/ntex/src/util/stream.rs @@ -189,7 +189,7 @@ mod tests { crate::rt::spawn(disp.map(|_| ())); let mut buf = BytesMut::new(); - let mut codec = ws::Codec::new().client_mode(); + let codec = ws::Codec::new().client_mode(); codec .encode(ws::Message::Text("test".to_string()), &mut buf) .unwrap(); diff --git a/ntex/src/web/app.rs b/ntex/src/web/app.rs index b3b2d1e3..e4feb9d0 100644 --- a/ntex/src/web/app.rs +++ b/ntex/src/web/app.rs @@ -5,12 +5,13 @@ use std::rc::Rc; use futures::future::{FutureExt, LocalBoxFuture}; -use crate::http::{Extensions, Request}; +use crate::http::Request; use crate::router::ResourceDef; use crate::service::boxed::{self, BoxServiceFactory}; use crate::service::{ apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform, }; +use crate::util::Extensions; use super::app_service::{AppEntry, AppFactory, AppRoutingFactory}; use super::config::{AppConfig, ServiceConfig}; diff --git a/ntex/src/web/app_service.rs b/ntex/src/web/app_service.rs index f7e6abb3..cf52f2cc 100644 --- a/ntex/src/web/app_service.rs +++ b/ntex/src/web/app_service.rs @@ -2,9 +2,10 @@ use std::{cell::RefCell, marker::PhantomData, rc::Rc, task::Context, task::Poll} use futures::future::{ok, FutureExt, LocalBoxFuture}; -use crate::http::{Extensions, Request, Response}; +use crate::http::{Request, Response}; use crate::router::{Path, ResourceDef, ResourceInfo, Router}; use crate::service::boxed::{self, BoxService, BoxServiceFactory}; +use crate::util::Extensions; use crate::{fn_service, Service, ServiceFactory}; use super::config::AppConfig; diff --git a/ntex/src/web/httprequest.rs b/ntex/src/web/httprequest.rs index ccb9e6e1..2923b1e7 100644 --- a/ntex/src/web/httprequest.rs +++ b/ntex/src/web/httprequest.rs @@ -5,10 +5,10 @@ use std::{fmt, net}; use futures::future::{ok, Ready}; use crate::http::{ - Extensions, HeaderMap, HttpMessage, Message, Method, Payload, RequestHead, Uri, - Version, + HeaderMap, HttpMessage, Message, Method, Payload, RequestHead, Uri, Version, }; use crate::router::Path; +use crate::util::Extensions; use super::config::AppConfig; use super::error::{ErrorRenderer, UrlGenerationError}; diff --git a/ntex/src/web/request.rs b/ntex/src/web/request.rs index 22a1797b..0e7b85d6 100644 --- a/ntex/src/web/request.rs +++ b/ntex/src/web/request.rs @@ -4,10 +4,10 @@ use std::rc::Rc; use std::{fmt, net}; use crate::http::{ - header, Extensions, HeaderMap, HttpMessage, Method, Payload, PayloadStream, - RequestHead, Response, Uri, Version, + header, HeaderMap, HttpMessage, Method, Payload, RequestHead, Response, Uri, Version, }; use crate::router::{Path, Resource}; +use crate::util::Extensions; use super::config::AppConfig; use super::error::{ErrorRenderer, WebResponseError}; @@ -204,7 +204,7 @@ impl WebRequest { #[inline] /// Get request's payload - pub fn take_payload(&mut self) -> Payload { + pub fn take_payload(&mut self) -> Payload { Rc::get_mut(&mut (self.req).0).unwrap().payload.take() } diff --git a/ntex/src/web/resource.rs b/ntex/src/web/resource.rs index 222c55d0..a59af9f4 100644 --- a/ntex/src/web/resource.rs +++ b/ntex/src/web/resource.rs @@ -2,12 +2,13 @@ use std::{cell::RefCell, fmt, rc::Rc, task::Context, task::Poll}; use futures::future::{ok, Either, Future, FutureExt, LocalBoxFuture, Ready}; -use crate::http::{Extensions, Response}; +use crate::http::Response; use crate::router::{IntoPattern, ResourceDef}; use crate::service::boxed::{self, BoxService, BoxServiceFactory}; use crate::service::{ apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, }; +use crate::util::Extensions; use super::dev::{insert_slesh, WebServiceConfig, WebServiceFactory}; use super::error::ErrorRenderer; diff --git a/ntex/src/web/scope.rs b/ntex/src/web/scope.rs index 1c96e18d..a6d1dffd 100644 --- a/ntex/src/web/scope.rs +++ b/ntex/src/web/scope.rs @@ -2,12 +2,13 @@ use std::{cell::RefCell, fmt, rc::Rc, task::Context, task::Poll}; use futures::future::{ok, Either, Future, FutureExt, LocalBoxFuture, Ready}; -use crate::http::{Extensions, Response}; +use crate::http::Response; use crate::router::{ResourceDef, ResourceInfo, Router}; use crate::service::boxed::{self, BoxService, BoxServiceFactory}; use crate::service::{ apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform, }; +use crate::util::Extensions; use super::config::ServiceConfig; use super::dev::{WebServiceConfig, WebServiceFactory}; diff --git a/ntex/src/web/server.rs b/ntex/src/web/server.rs index c21ef475..119cc16b 100644 --- a/ntex/src/web/server.rs +++ b/ntex/src/web/server.rs @@ -24,9 +24,9 @@ use super::config::AppConfig; struct Config { host: Option, keep_alive: KeepAlive, - client_timeout: u64, - client_disconnect: u64, - handshake_timeout: u64, + client_timeout: u16, + client_disconnect: u16, + handshake_timeout: u16, } /// An HTTP Server. @@ -148,7 +148,7 @@ where self } - /// Set server client timeout in milliseconds for first request. + /// Set server client timeout in seconds for first request. /// /// Defines a timeout for reading client request header. If a client does not transmit /// the entire set headers within this time, the request is terminated with @@ -157,12 +157,12 @@ where /// To disable timeout set value to 0. /// /// By default client timeout is set to 5 seconds. - pub fn client_timeout(self, val: u64) -> Self { + pub fn client_timeout(self, val: u16) -> Self { self.config.lock().unwrap().client_timeout = val; self } - /// Set server connection disconnect timeout in milliseconds. + /// Set server connection disconnect timeout in seconds. /// /// Defines a timeout for shutdown connection. If a shutdown procedure does not complete /// within this time, the request is dropped. @@ -170,18 +170,18 @@ where /// To disable timeout set value to 0. /// /// By default client timeout is set to 5 seconds. - pub fn disconnect_timeout(self, val: u64) -> Self { + pub fn disconnect_timeout(self, val: u16) -> Self { self.config.lock().unwrap().client_disconnect = val; self } - /// Set server ssl handshake timeout in milliseconds. + /// Set server ssl handshake timeout in seconds. /// /// Defines a timeout for connection ssl handshake negotiation. /// To disable timeout set value to 0. /// /// By default handshake timeout is set to 5 seconds. - pub fn ssl_handshake_timeout(self, val: u64) -> Self { + pub fn ssl_handshake_timeout(self, val: u16) -> Self { self.config.lock().unwrap().handshake_timeout = val; self } diff --git a/ntex/src/web/service.rs b/ntex/src/web/service.rs index eabaa0d2..7a428e5c 100644 --- a/ntex/src/web/service.rs +++ b/ntex/src/web/service.rs @@ -1,8 +1,8 @@ use std::rc::Rc; -use crate::http::Extensions; use crate::router::{IntoPattern, ResourceDef}; use crate::service::{boxed, IntoServiceFactory, ServiceFactory}; +use crate::util::Extensions; use super::config::AppConfig; use super::dev::insert_slesh; diff --git a/ntex/src/web/test.rs b/ntex/src/web/test.rs index bf6308f8..c29fbf45 100644 --- a/ntex/src/web/test.rs +++ b/ntex/src/web/test.rs @@ -1,10 +1,8 @@ //! Various helpers for ntex applications to use during testing. -use std::convert::TryFrom; -use std::error::Error; -use std::net::SocketAddr; -use std::rc::Rc; -use std::sync::mpsc; -use std::{fmt, net, thread, time}; +use std::{ + convert::TryFrom, error::Error, fmt, net, net::SocketAddr, rc::Rc, sync::mpsc, + thread, time, +}; use bytes::{Bytes, BytesMut}; use futures::future::ok; @@ -22,12 +20,11 @@ use crate::http::client::{Client, ClientRequest, ClientResponse, Connector}; use crate::http::error::{HttpError, PayloadError, ResponseError}; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::test::TestRequest as HttpTestRequest; -use crate::http::{ - Extensions, HttpService, Method, Payload, Request, StatusCode, Uri, Version, -}; +use crate::http::{HttpService, Method, Payload, Request, StatusCode, Uri, Version}; use crate::router::{Path, ResourceDef}; use crate::rt::{time::delay_for, System}; use crate::server::Server; +use crate::util::Extensions; use crate::{map_config, IntoService, IntoServiceFactory, Service, ServiceFactory}; use crate::web::config::AppConfig; @@ -776,7 +773,7 @@ where pub struct TestServerConfig { tp: HttpVer, stream: StreamType, - client_timeout: u64, + client_timeout: u16, } #[derive(Clone, Debug)] @@ -854,8 +851,8 @@ impl TestServerConfig { self } - /// Set server client timeout in milliseconds for first request. - pub fn client_timeout(mut self, val: u64) -> Self { + /// Set server client timeout in seconds for first request. + pub fn client_timeout(mut self, val: u16) -> Self { self.client_timeout = val; self } @@ -927,13 +924,11 @@ impl TestServer { self.client.request(method, path.as_ref()) } - pub async fn load_body( + /// Load response's body + pub async fn load_body( &self, - mut response: ClientResponse, - ) -> Result - where - S: Stream> + Unpin + 'static, - { + mut response: ClientResponse, + ) -> Result { response.body().limit(10_485_760).await } diff --git a/ntex/src/web/types/data.rs b/ntex/src/web/types/data.rs index 6f9ccf3f..194b18b3 100644 --- a/ntex/src/web/types/data.rs +++ b/ntex/src/web/types/data.rs @@ -1,9 +1,9 @@ -use std::ops::Deref; -use std::sync::Arc; +use std::{ops::Deref, sync::Arc}; use futures::future::{err, ok, Ready}; -use crate::http::{Extensions, Payload}; +use crate::http::Payload; +use crate::util::Extensions; use crate::web::error::{DataExtractorError, ErrorRenderer}; use crate::web::extract::FromRequest; use crate::web::httprequest::HttpRequest; diff --git a/ntex/src/ws/codec.rs b/ntex/src/ws/codec.rs index 78e71397..07c9d67b 100644 --- a/ntex/src/ws/codec.rs +++ b/ntex/src/ws/codec.rs @@ -1,4 +1,5 @@ use bytes::{Bytes, BytesMut}; +use std::cell::Cell; use crate::codec::{Decoder, Encoder}; @@ -49,10 +50,10 @@ pub enum Item { Last(Bytes), } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone)] /// WebSockets protocol codec pub struct Codec { - flags: Flags, + flags: Cell, max_size: usize, } @@ -69,7 +70,7 @@ impl Codec { pub fn new() -> Codec { Codec { max_size: 65_536, - flags: Flags::SERVER, + flags: Cell::new(Flags::SERVER), } } @@ -84,10 +85,22 @@ impl Codec { /// Set decoder to client mode. /// /// By default decoder works in server mode. - pub fn client_mode(mut self) -> Self { - self.flags.remove(Flags::SERVER); + pub fn client_mode(self) -> Self { + self.remove_flags(Flags::SERVER); self } + + fn insert_flags(&self, f: Flags) { + let mut flags = self.flags.get(); + flags.insert(f); + self.flags.set(flags); + } + + fn remove_flags(&self, f: Flags) { + let mut flags = self.flags.get(); + flags.remove(f); + self.flags.set(flags); + } } impl Default for Codec { @@ -100,90 +113,92 @@ impl Encoder for Codec { type Item = Message; type Error = ProtocolError; - fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + fn encode(&self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { Message::Text(txt) => Parser::write_message( dst, txt, OpCode::Text, true, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), ), Message::Binary(bin) => Parser::write_message( dst, bin, OpCode::Binary, true, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), ), Message::Ping(txt) => Parser::write_message( dst, txt, OpCode::Ping, true, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), ), Message::Pong(txt) => Parser::write_message( dst, txt, OpCode::Pong, true, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), + ), + Message::Close(reason) => Parser::write_close( + dst, + reason, + !self.flags.get().contains(Flags::SERVER), ), - Message::Close(reason) => { - Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER)) - } Message::Continuation(cont) => match cont { Item::FirstText(data) => { - if self.flags.contains(Flags::W_CONTINUATION) { + if self.flags.get().contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { - self.flags.insert(Flags::W_CONTINUATION); + self.insert_flags(Flags::W_CONTINUATION); Parser::write_message( dst, &data[..], OpCode::Text, false, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), ) } } Item::FirstBinary(data) => { - if self.flags.contains(Flags::W_CONTINUATION) { + if self.flags.get().contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { - self.flags.insert(Flags::W_CONTINUATION); + self.insert_flags(Flags::W_CONTINUATION); Parser::write_message( dst, &data[..], OpCode::Binary, false, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), ) } } Item::Continue(data) => { - if self.flags.contains(Flags::W_CONTINUATION) { + if self.flags.get().contains(Flags::W_CONTINUATION) { Parser::write_message( dst, &data[..], OpCode::Continue, false, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), ) } else { return Err(ProtocolError::ContinuationNotStarted); } } Item::Last(data) => { - if self.flags.contains(Flags::W_CONTINUATION) { - self.flags.remove(Flags::W_CONTINUATION); + if self.flags.get().contains(Flags::W_CONTINUATION) { + self.remove_flags(Flags::W_CONTINUATION); Parser::write_message( dst, &data[..], OpCode::Continue, true, - !self.flags.contains(Flags::SERVER), + !self.flags.get().contains(Flags::SERVER), ) } else { return Err(ProtocolError::ContinuationNotStarted); @@ -199,14 +214,15 @@ impl Decoder for Codec { type Item = Frame; type Error = ProtocolError; - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { + fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { + match Parser::parse(src, self.flags.get().contains(Flags::SERVER), self.max_size) + { Ok(Some((finished, opcode, payload))) => { // handle continuation if !finished { return match opcode { OpCode::Continue => { - if self.flags.contains(Flags::CONTINUATION) { + if self.flags.get().contains(Flags::CONTINUATION) { Ok(Some(Frame::Continuation(Item::Continue( payload .map(|pl| pl.freeze()) @@ -217,8 +233,8 @@ impl Decoder for Codec { } } OpCode::Binary => { - if !self.flags.contains(Flags::CONTINUATION) { - self.flags.insert(Flags::CONTINUATION); + if !self.flags.get().contains(Flags::CONTINUATION) { + self.insert_flags(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstBinary( payload .map(|pl| pl.freeze()) @@ -229,8 +245,8 @@ impl Decoder for Codec { } } OpCode::Text => { - if !self.flags.contains(Flags::CONTINUATION) { - self.flags.insert(Flags::CONTINUATION); + if !self.flags.get().contains(Flags::CONTINUATION) { + self.insert_flags(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstText( payload .map(|pl| pl.freeze()) @@ -249,8 +265,8 @@ impl Decoder for Codec { match opcode { OpCode::Continue => { - if self.flags.contains(Flags::CONTINUATION) { - self.flags.remove(Flags::CONTINUATION); + if self.flags.get().contains(Flags::CONTINUATION) { + self.remove_flags(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::Last( payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), )))) diff --git a/ntex/src/ws/stream.rs b/ntex/src/ws/stream.rs index c924b401..72fdd830 100644 --- a/ntex/src/ws/stream.rs +++ b/ntex/src/ws/stream.rs @@ -187,7 +187,7 @@ mod tests { let mut decoder = StreamDecoder::new(rx); let mut buf = BytesMut::new(); - let mut codec = Codec::new().client_mode(); + let codec = Codec::new().client_mode(); codec .encode(Message::Text("test1".to_string()), &mut buf) .unwrap(); diff --git a/ntex/tests/http_awc_ws.rs b/ntex/tests/http_awc_ws.rs index 924e8179..c8ad2b42 100644 --- a/ntex/tests/http_awc_ws.rs +++ b/ntex/tests/http_awc_ws.rs @@ -4,41 +4,58 @@ use bytes::Bytes; use futures::future::ok; use futures::{SinkExt, StreamExt}; -use ntex::codec::Framed; +use ntex::framed::{DispatchItem, Dispatcher, State}; use ntex::http::test::server as test_server; use ntex::http::ws::handshake_response; use ntex::http::{body::BodySize, h1, HttpService, Request, Response}; -use ntex::util::framed::Dispatcher; use ntex::ws; -async fn ws_service(req: ws::Frame) -> Result, io::Error> { - let item = match req { - ws::Frame::Ping(msg) => ws::Message::Pong(msg), - ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8(Vec::from(text.as_ref())).unwrap()) - } - ws::Frame::Binary(bin) => ws::Message::Binary(bin), - ws::Frame::Close(reason) => ws::Message::Close(reason), - _ => ws::Message::Close(None), +async fn ws_service( + msg: DispatchItem, +) -> Result, io::Error> { + println!("TEST: {:?}", msg); + let msg = match msg { + DispatchItem::Item(msg) => match msg { + ws::Frame::Ping(msg) => ws::Message::Pong(msg), + ws::Frame::Text(text) => { + ws::Message::Text(String::from_utf8(Vec::from(text.as_ref())).unwrap()) + } + ws::Frame::Binary(bin) => ws::Message::Binary(bin), + ws::Frame::Close(reason) => ws::Message::Close(reason), + _ => ws::Message::Close(None), + }, + _ => return Ok(None), }; - Ok(Some(item)) + Ok(Some(msg)) } #[ntex::test] async fn test_simple() { + std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace"); + env_logger::init(); + let mut srv = test_server(|| { HttpService::build() - .upgrade(|(req, mut framed): (Request, Framed<_, _>)| { + .upgrade(|(req, state, mut codec): (Request, State, h1::Codec)| { async move { let res = handshake_response(req.head()).finish(); - // send handshake response - framed - .send(h1::Message::Item((res.drop_body(), BodySize::None))) - .await?; + + // send handshake respone + state + .write_item( + h1::Message::Item((res.drop_body(), BodySize::None)), + &mut codec, + ) + .unwrap(); // start websocket service - let framed = framed.into_framed(ws::Codec::default()); - Dispatcher::new(framed, ws_service).await + Dispatcher::from_state( + ws::Codec::default(), + state, + ws_service, + Default::default(), + ) + .await } }) .finish(|_| ok::<_, io::Error>(Response::NotFound())) diff --git a/ntex/tests/http_openssl.rs b/ntex/tests/http_openssl.rs index a4c52f0d..1419e20e 100644 --- a/ntex/tests/http_openssl.rs +++ b/ntex/tests/http_openssl.rs @@ -114,7 +114,7 @@ async fn test_h2_body() -> io::Result<()> { let data = "HELLOWORLD".to_owned().repeat(64 * 1024); let mut srv = test_server(move || { HttpService::build() - .h2(|mut req: Request<_>| async move { + .h2(|mut req: Request| async move { let body = load_body(req.take_payload()) .await .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; @@ -443,7 +443,7 @@ async fn test_ssl_handshake_timeout() { let srv = test_server(move || { HttpService::build() - .ssl_handshake_timeout(50) + .ssl_handshake_timeout(1) .h2(|_| ok::<_, io::Error>(Response::Ok().finish())) .openssl(ssl_acceptor()) .map_err(|_| ()) diff --git a/ntex/tests/http_rustls.rs b/ntex/tests/http_rustls.rs index 1d870d3d..bab56192 100644 --- a/ntex/tests/http_rustls.rs +++ b/ntex/tests/http_rustls.rs @@ -105,7 +105,7 @@ async fn test_h2_body1() -> io::Result<()> { let data = "HELLOWORLD".to_owned().repeat(64 * 1024); let mut srv = test_server(move || { HttpService::build() - .h2(|mut req: Request<_>| async move { + .h2(|mut req: Request| async move { let body = load_body(req.take_payload()) .await .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; @@ -446,7 +446,7 @@ async fn test_ssl_handshake_timeout() { let srv = test_server(move || { HttpService::build() - .ssl_handshake_timeout(50) + .ssl_handshake_timeout(1) .h2(|_| ok::<_, io::Error>(Response::Ok().finish())) .rustls(ssl_acceptor()) }); diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 1046f82b..a457ce9a 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -178,7 +178,7 @@ async fn test_chunked_payload() { async fn test_slow_request() { let srv = test_server(|| { HttpService::build() - .client_timeout(100) + .client_timeout(1) .finish(|_| future::ok::<_, io::Error>(Response::Ok().finish())) .tcp() }); diff --git a/ntex/tests/http_ws.rs b/ntex/tests/http_ws.rs index c1dec407..94b396d3 100644 --- a/ntex/tests/http_ws.rs +++ b/ntex/tests/http_ws.rs @@ -1,47 +1,39 @@ -use std::cell::Cell; -use std::io; -use std::marker::PhantomData; -use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; +use std::{cell::Cell, io, pin::Pin}; use bytes::Bytes; use futures::{future, Future, SinkExt, StreamExt}; -use ntex::codec::{AsyncRead, AsyncWrite, Framed}; -use ntex::http::ws::handshake; -use ntex::http::{body, h1, test, HttpService, Request, Response}; +use ntex::framed::{DispatchItem, Dispatcher, State, Timer}; +use ntex::http::{body, h1, test, ws::handshake, HttpService, Request, Response}; use ntex::service::{fn_factory, Service}; -use ntex::util::framed::Dispatcher; use ntex::ws; -struct WsService(Arc, Cell)>>); +struct WsService(Arc>>); -impl WsService { +impl WsService { fn new() -> Self { - WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false))))) + WsService(Arc::new(Mutex::new(Cell::new(false)))) } fn set_polled(&self) { - *self.0.lock().unwrap().1.get_mut() = true; + *self.0.lock().unwrap().get_mut() = true; } fn was_polled(&self) -> bool { - self.0.lock().unwrap().1.get() + self.0.lock().unwrap().get() } } -impl Clone for WsService { +impl Clone for WsService { fn clone(&self) -> Self { WsService(self.0.clone()) } } -impl Service for WsService -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - type Request = (Request, Framed); +impl Service for WsService { + type Request = (Request, State, h1::Codec); type Response = (); type Error = io::Error; type Future = Pin>>>; @@ -51,16 +43,15 @@ where Poll::Ready(Ok(())) } - fn call(&self, (req, mut framed): Self::Request) -> Self::Future { + fn call(&self, (req, state, mut codec): Self::Request) -> Self::Future { let fut = async move { let res = handshake(req.head()).unwrap().message_body(()); - framed - .send((res, body::BodySize::None).into()) - .await + state + .write_item((res, body::BodySize::None).into(), &mut codec) .unwrap(); - Dispatcher::new(framed.into_framed(ws::Codec::new()), service) + Dispatcher::from_state(ws::Codec::new(), state, service, Timer::default()) .await .map_err(|_| panic!()) }; @@ -69,16 +60,21 @@ where } } -async fn service(msg: ws::Frame) -> Result, io::Error> { +async fn service( + msg: DispatchItem, +) -> Result, io::Error> { let msg = match msg { - ws::Frame::Ping(msg) => ws::Message::Pong(msg), - ws::Frame::Text(text) => { - ws::Message::Text(String::from_utf8_lossy(&text).to_string()) - } - ws::Frame::Binary(bin) => ws::Message::Binary(bin), - ws::Frame::Continuation(item) => ws::Message::Continuation(item), - ws::Frame::Close(reason) => ws::Message::Close(reason), - _ => panic!(), + DispatchItem::Item(msg) => match msg { + ws::Frame::Ping(msg) => ws::Message::Pong(msg), + ws::Frame::Text(text) => { + ws::Message::Text(String::from_utf8_lossy(&text).to_string()) + } + ws::Frame::Binary(bin) => ws::Message::Binary(bin), + ws::Frame::Continuation(item) => ws::Message::Continuation(item), + ws::Frame::Close(reason) => ws::Message::Close(reason), + _ => panic!(), + }, + _ => return Ok(None), }; Ok(Some(msg)) } diff --git a/ntex/tests/web_server.rs b/ntex/tests/web_server.rs index 9b598628..4d026360 100644 --- a/ntex/tests/web_server.rs +++ b/ntex/tests/web_server.rs @@ -1054,7 +1054,7 @@ async fn test_server_cookies() { async fn test_slow_request() { use std::net; - let srv = test::server_with(test::config().client_timeout(200), || { + let srv = test::server_with(test::config().client_timeout(1), || { App::new() .service(web::resource("/").route(web::to(|| async { HttpResponse::Ok() }))) });