diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index b0fe73f2..a3580e3e 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.0-b.6] - 2021-01-24 + +* http: Pass io stream to upgrade handler + ## [0.2.0-b.5] - 2021-01-23 * accept shared ref in some methods of framed::State type diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 7320e81f..062b97e7 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.2.0-b.5" +version = "0.2.0-b.6" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" diff --git a/ntex/src/framed/read.rs b/ntex/src/framed/read.rs index f2736b25..3c094b23 100644 --- a/ntex/src/framed/read.rs +++ b/ntex/src/framed/read.rs @@ -39,6 +39,9 @@ where if self.state.is_io_shutdown() { log::trace!("read task is instructed to shutdown"); Poll::Ready(()) + } else if self.state.is_io_stop() { + self.state.dsp_wake_task(); + Poll::Ready(()) } else if self.state.is_read_paused() { self.state.register_read_task(cx.waker()); Poll::Pending diff --git a/ntex/src/framed/state.rs b/ntex/src/framed/state.rs index 1b6b52c2..4dbb9e95 100644 --- a/ntex/src/framed/state.rs +++ b/ntex/src/framed/state.rs @@ -13,22 +13,26 @@ use crate::task::LocalWaker; const HW: usize = 8 * 1024; bitflags::bitflags! { - pub struct Flags: u8 { + pub struct Flags: u16 { const DSP_STOP = 0b0000_0001; const DSP_KEEPALIVE = 0b0000_0010; + /// io error occured const IO_ERR = 0b0000_0100; - const IO_SHUTDOWN = 0b0000_1000; + /// stop io tasks + const IO_STOP = 0b0000_1000; + /// shutdown io tasks + const IO_SHUTDOWN = 0b0001_0000; /// pause io read - const RD_PAUSED = 0b0001_0000; + const RD_PAUSED = 0b0010_0000; /// new data is available - const RD_READY = 0b0010_0000; + const RD_READY = 0b0100_0000; /// write buffer is full - const WR_NOT_READY = 0b0100_0000; + const WR_NOT_READY = 0b1000_0000; - const ST_DSP_ERR = 0b1000_0000; + const ST_DSP_ERR = 0b10000_0000; } } @@ -148,6 +152,11 @@ impl State { .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) } + #[inline] + pub fn is_io_stop(&self) -> bool { + self.0.flags.get().contains(Flags::IO_STOP) + } + #[inline] /// Check if read buffer has new data pub fn is_read_ready(&self) -> bool { @@ -317,6 +326,17 @@ impl State { self.0.dispatch_task.register(waker); } + #[inline] + /// Stop io tasks + pub fn dsp_stop_io(&self, waker: &Waker) { + let mut flags = self.0.flags.get(); + flags.insert(Flags::IO_STOP); + self.0.flags.set(flags); + self.0.read_task.wake(); + self.0.write_task.wake(); + self.0.dispatch_task.register(waker); + } + #[inline] /// Wake dispatcher pub fn dsp_wake_task(&self) { @@ -329,6 +349,14 @@ impl State { self.0.dispatch_task.register(waker); } + #[inline] + /// Reset io stop flags + pub fn reset_io_stop(&self) { + let mut flags = self.0.flags.get(); + flags.remove(Flags::IO_STOP); + self.0.flags.set(flags); + } + fn mark_io_error(&self) { self.0.read_task.wake(); self.0.write_task.wake(); diff --git a/ntex/src/framed/write.rs b/ntex/src/framed/write.rs index a8edae20..5d48bb6e 100644 --- a/ntex/src/framed/write.rs +++ b/ntex/src/framed/write.rs @@ -74,6 +74,9 @@ where if this.state.is_io_err() { log::trace!("write io is closed"); return Poll::Ready(()); + } else if this.state.is_io_stop() { + self.state.dsp_wake_task(); + return Poll::Ready(()); } match this.st { @@ -224,7 +227,7 @@ where } } } - // log::trace!("flushed {} bytes", written); + log::trace!("flushed {} bytes", written); // remove written data if written == len { diff --git a/ntex/src/http/builder.rs b/ntex/src/http/builder.rs index 05c321be..c86aa7ef 100644 --- a/ntex/src/http/builder.rs +++ b/ntex/src/http/builder.rs @@ -55,7 +55,7 @@ where X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory, + U: ServiceFactory, U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, @@ -142,7 +142,7 @@ where F: IntoServiceFactory, U1: ServiceFactory< Config = (), - Request = (Request, State, Codec), + Request = (Request, T, State, Codec), Response = (), >, U1::Error: fmt::Display + Error + 'static, diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index e6ea85b0..07f80b38 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,21 +1,20 @@ //! Framed transport dispatcher -use std::error::Error; use std::task::{Context, Poll}; use std::{ - cell::RefCell, fmt, future::Future, marker::PhantomData, net, pin::Pin, rc::Rc, - time::Duration, time::Instant, + cell::RefCell, error::Error, fmt, marker::PhantomData, net, pin::Pin, rc::Rc, time, }; use bytes::Bytes; +use futures::Future; -use crate::codec::{AsyncRead, AsyncWrite, Decoder}; +use crate::codec::{AsyncRead, AsyncWrite}; use crate::framed::{ReadTask, State as IoState, WriteTask}; use crate::service::Service; use crate::http; use crate::http::body::{BodySize, MessageBody, ResponseBody}; use crate::http::config::DispatcherConfig; -use crate::http::error::{DispatchError, PayloadError, ResponseError}; +use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError}; use crate::http::helpers::DataFactory; use crate::http::request::Request; use crate::http::response::Response; @@ -37,11 +36,11 @@ bitflags::bitflags! { pin_project_lite::pin_project! { /// Dispatcher for HTTP/1.1 protocol - pub struct Dispatcher { + pub struct Dispatcher { #[pin] call: CallState, st: State, - inner: DispatcherInner, + inner: DispatcherInner, } } @@ -50,6 +49,7 @@ enum State { ReadRequest, ReadPayload, SendPayload { body: ResponseBody }, + Upgrade(Option), Stop, } @@ -63,12 +63,13 @@ pin_project_lite::pin_project! { } } -struct DispatcherInner { +struct DispatcherInner { + io: Option>>, flags: Flags, codec: Codec, config: Rc>, state: IoState, - expire: Instant, + expire: time::Instant, error: Option, payload: Option<(PayloadDecoder, PayloadSender)>, peer_addr: Option, @@ -77,34 +78,38 @@ struct DispatcherInner { } #[derive(Copy, Clone, PartialEq, Eq)] -enum PollPayloadStatus { +enum ReadPayloadStatus { Done, Updated, Pending, Dropped, } -impl Dispatcher +enum WritePayloadStatus { + Next(State), + Pause, + Continue, +} + +impl Dispatcher where + T: AsyncRead + AsyncWrite + Unpin + 'static, S: Service, S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, X::Error: ResponseError, - U: Service, + U: Service, U::Error: Error + fmt::Display, { /// Construct new `Dispatcher` instance with outgoing messages stream. - pub(in crate::http) fn new( + pub(in crate::http) fn new( io: T, config: Rc>, peer_addr: Option, on_connect_data: Option>, - ) -> Self - where - T: AsyncRead + AsyncWrite + Unpin + 'static, - { + ) -> Self { let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled()); let state = IoState::new(); @@ -115,18 +120,19 @@ where // slow-request timer if config.client_timeout != 0 { - expire += Duration::from_secs(config.client_timeout); + expire += time::Duration::from_secs(config.client_timeout); config.timer_h1.register(expire, expire, &state); } // start support io tasks crate::rt::spawn(ReadTask::new(io.clone(), state.clone())); - crate::rt::spawn(WriteTask::new(io, state.clone())); + crate::rt::spawn(WriteTask::new(io.clone(), state.clone())); Dispatcher { call: CallState::None, st: State::ReadRequest, inner: DispatcherInner { + io: Some(io), flags: Flags::empty(), error: None, payload: None, @@ -142,15 +148,16 @@ where } } -impl Future for Dispatcher +impl Future for Dispatcher where + T: AsyncRead + AsyncWrite + Unpin + 'static, S: Service, S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, U::Error: Error + fmt::Display + 'static, { type Output = Result<(), DispatchError>; @@ -180,7 +187,7 @@ where // we might need to read more data into a request payload // (ie service future can wait for payload data) if this.inner.poll_read_payload(cx) - != PollPayloadStatus::Updated + != ReadPayloadStatus::Updated { return Poll::Pending; } @@ -197,26 +204,15 @@ where b"HTTP/1.1 100 Continue\r\n\r\n", ) }); - Some(if this.inner.flags.contains(Flags::UPGRADE) { - // Handle UPGRADE request - CallState::Upgrade { - fut: this - .inner - .config - .upgrade - .as_ref() - .unwrap() - .call(( - req, - this.inner.state.clone(), - this.inner.codec.clone(), - )), - } + if this.inner.flags.contains(Flags::UPGRADE) { + this.inner.state.dsp_stop_io(cx.waker()); + *this.st = State::Upgrade(Some(req)); + return Poll::Pending; } else { - CallState::Service { + Some(CallState::Service { fut: this.inner.config.service.call(req), - } - }) + }) + } } Err(e) => { *this.st = this.inner.handle_error(e, true); @@ -273,7 +269,11 @@ where if this.inner.state.is_read_ready() { match this.inner.state.decode_item(&this.inner.codec) { Ok(Some((mut req, pl))) => { - log::trace!("http message is received: {:?}", req); + log::trace!( + "http message is received: {:?} and payload {:?}", + req, + pl + ); req.head_mut().peer_addr = this.inner.peer_addr; // configure request payload @@ -313,48 +313,46 @@ where on_connect.set(&mut req.extensions_mut()); } - // call service - *this.st = State::Call; - this.call.set(if req.head().expect() { + if req.head().expect() { + // call service + *this.st = State::Call; // Handle `EXPECT: 100-Continue` header - CallState::Expect { + this.call.set(CallState::Expect { fut: this.inner.config.expect.call(req), - } + }); } else if upgrade { - log::trace!("initate upgrade handling"); + log::trace!("prep io for upgrade handler"); // Handle UPGRADE request - CallState::Upgrade { - fut: this - .inner - .config - .upgrade - .as_ref() - .unwrap() - .call(( - req, - this.inner.state.clone(), - this.inner.codec.clone(), - )), - } + this.inner.state.dsp_stop_io(cx.waker()); + *this.st = State::Upgrade(Some(req)); + return Poll::Pending; } else { // Handle normal requests - CallState::Service { + *this.st = State::Call; + this.call.set(CallState::Service { fut: this.inner.config.service.call(req), - } - }); + }); + } } Ok(None) => { - // if connection is not keep-alive then disconnect + log::trace!("not enough data to decode next frame, register dispatch task"); + + // if io error occured or connection is not keep-alive + // then disconnect if this.inner.flags.contains(Flags::STARTED) - && !this.inner.flags.contains(Flags::KEEPALIVE) + && (!this.inner.flags.contains(Flags::KEEPALIVE) + || !this.inner.codec.keepalive_enabled() + || this.inner.state.is_io_err()) { *this.st = State::Stop; + this.inner.state.dsp_mark_stopped(); continue; } this.inner.state.dsp_read_more_data(cx.waker()); return Poll::Pending; } Err(err) => { + log::trace!("malformed request: {:?}", err); // Malformed requests, respond with 400 let (res, body) = Response::BadRequest().finish().into_parts(); @@ -379,31 +377,70 @@ where // consume request's payload State::ReadPayload => loop { match this.inner.poll_read_payload(cx) { - PollPayloadStatus::Updated => continue, - PollPayloadStatus::Pending => return Poll::Pending, - PollPayloadStatus::Done => { + ReadPayloadStatus::Updated => continue, + ReadPayloadStatus::Pending => return Poll::Pending, + ReadPayloadStatus::Done => { *this.st = { this.inner.reset_keepalive(); State::ReadRequest } } - PollPayloadStatus::Dropped => *this.st = State::Stop, + ReadPayloadStatus::Dropped => *this.st = State::Stop, } break; }, // send response body State::SendPayload { ref mut body } => { - this.inner.poll_read_payload(cx); + if this.inner.state.is_io_err() { + *this.st = State::Stop; + } else { + this.inner.poll_read_payload(cx); - match body.poll_next_chunk(cx) { - Poll::Ready(item) => { - if let Some(st) = this.inner.send_payload(item) { - *this.st = st; - } + match body.poll_next_chunk(cx) { + Poll::Ready(item) => match this.inner.send_payload(item) { + WritePayloadStatus::Next(st) => { + *this.st = st; + } + WritePayloadStatus::Pause => { + this.inner.state.dsp_flush_write_data(cx.waker()); + return Poll::Pending; + } + WritePayloadStatus::Continue => (), + }, + Poll::Pending => return Poll::Pending, } - Poll::Pending => return Poll::Pending, } } + // stop io tasks and call upgrade service + State::Upgrade(ref mut req) => { + // check if all io tasks have been stopped + let io = if Rc::strong_count(this.inner.io.as_ref().unwrap()) == 1 { + if let Ok(io) = Rc::try_unwrap(this.inner.io.take().unwrap()) { + io.into_inner() + } else { + return Poll::Ready(Err(DispatchError::InternalError)); + } + } else { + // wait next task stop + this.inner.state.dsp_register_task(cx.waker()); + return Poll::Pending; + }; + log::trace!("initate upgrade handling"); + + let req = req.take().unwrap(); + *this.st = State::Call; + this.inner.state.reset_io_stop(); + + // Handle UPGRADE request + this.call.set(CallState::Upgrade { + fut: this.inner.config.upgrade.as_ref().unwrap().call(( + req, + io, + this.inner.state.clone(), + this.inner.codec.clone(), + )), + }); + } // prepare to shutdown State::Stop => { this.inner.state.shutdown_io(); @@ -426,7 +463,7 @@ where } } -impl DispatcherInner +impl DispatcherInner where S: Service, S::Error: ResponseError + 'static, @@ -442,8 +479,8 @@ where fn reset_keepalive(&mut self) { // re-register keep-alive if self.flags.contains(Flags::KEEPALIVE) { - let expire = - self.config.timer_h1.now() + Duration::from_secs(self.config.keep_alive); + let expire = self.config.timer_h1.now() + + time::Duration::from_secs(self.config.keep_alive); self.config .timer_h1 .register(expire, self.expire, &self.state); @@ -512,18 +549,25 @@ where fn send_payload( &mut self, item: Option>>, - ) -> Option> { + ) -> WritePayloadStatus { match item { Some(Ok(item)) => { trace!("Got response chunk: {:?}", item.len()); - if let Err(err) = self + match self .state .write_item(Message::Chunk(Some(item)), &self.codec) { - self.error = Some(DispatchError::Encode(err)); - Some(State::Stop) - } else { - None + Err(err) => { + self.error = Some(DispatchError::Encode(err)); + WritePayloadStatus::Next(State::Stop) + } + Ok(has_space) => { + if has_space { + WritePayloadStatus::Continue + } else { + WritePayloadStatus::Pause + } + } } } None => { @@ -532,24 +576,24 @@ where self.state.write_item(Message::Chunk(None), &self.codec) { self.error = Some(DispatchError::Encode(err)); - Some(State::Stop) + WritePayloadStatus::Next(State::Stop) } else if self.payload.is_some() { - Some(State::ReadPayload) + WritePayloadStatus::Next(State::ReadPayload) } else { self.reset_keepalive(); - Some(State::ReadRequest) + WritePayloadStatus::Next(State::ReadRequest) } } Some(Err(e)) => { trace!("Error during response body poll: {:?}", e); self.error = Some(DispatchError::ResponsePayload(e)); - Some(State::Stop) + WritePayloadStatus::Next(State::Stop) } } } /// Process request's payload - fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> PollPayloadStatus { + fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> ReadPayloadStatus { // check if payload data is required if let Some(ref mut payload) = self.payload { match payload.1.poll_data_required(cx) { @@ -557,7 +601,7 @@ where // read request payload let mut updated = false; loop { - let item = self.state.with_read_buf(|buf| payload.0.decode(buf)); + let item = self.state.decode_item(&payload.0); match item { Ok(Some(PayloadItem::Chunk(chunk))) => { updated = true; @@ -567,40 +611,434 @@ where payload.1.feed_eof(); self.payload = None; if !updated { - return PollPayloadStatus::Done; + return ReadPayloadStatus::Done; } break; } Ok(None) => { - self.state.dsp_read_more_data(cx.waker()); - break; + if self.state.is_io_err() { + payload.1.set_error(PayloadError::EncodingCorrupted); + self.payload = None; + self.error = Some(ParseError::Incomplete.into()); + return ReadPayloadStatus::Dropped; + } else { + self.state.dsp_read_more_data(cx.waker()); + break; + } } Err(e) => { payload.1.set_error(PayloadError::EncodingCorrupted); self.payload = None; self.error = Some(DispatchError::Parse(e)); - return PollPayloadStatus::Dropped; + return ReadPayloadStatus::Dropped; } } } if updated { - PollPayloadStatus::Updated + ReadPayloadStatus::Updated } else { - PollPayloadStatus::Pending + ReadPayloadStatus::Pending } } - PayloadStatus::Pause => PollPayloadStatus::Pending, + PayloadStatus::Pause => ReadPayloadStatus::Pending, PayloadStatus::Dropped => { // service call is not interested in payload // wait until future completes and then close // connection self.payload = None; self.error = Some(DispatchError::PayloadIsNotConsumed); - PollPayloadStatus::Dropped + ReadPayloadStatus::Dropped } } } else { - PollPayloadStatus::Done + ReadPayloadStatus::Done } } } + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + use std::{io, sync::Arc}; + + use bytes::{Bytes, BytesMut}; + use futures::future::{lazy, ok, FutureExt}; + use futures::StreamExt; + use rand::Rng; + + use super::*; + use crate::codec::Decoder; + use crate::http::config::{DispatcherConfig, ServiceConfig}; + use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler}; + use crate::http::{body, Request, ResponseHead, StatusCode}; + use crate::rt::time::delay_for; + use crate::service::IntoService; + use crate::testing::Io; + + const BUFFER_SIZE: usize = 32_768; + + /// Create http/1 dispatcher. + pub(crate) fn h1( + stream: Io, + service: F, + ) -> Dispatcher> + where + F: IntoService, + S: Service, + S::Error: ResponseError + 'static, + S::Response: Into>, + B: MessageBody, + { + Dispatcher::new( + stream, + Rc::new(DispatcherConfig::new( + ServiceConfig::default(), + service.into_service(), + ExpectHandler, + None, + )), + None, + None, + ) + } + + pub(crate) fn spawn_h1(stream: Io, service: F) + where + F: IntoService, + S: Service + 'static, + S::Error: ResponseError, + S::Response: Into>, + B: MessageBody + 'static, + { + crate::rt::spawn( + Dispatcher::>::new( + stream, + Rc::new(DispatcherConfig::new( + ServiceConfig::default(), + service.into_service(), + ExpectHandler, + None, + )), + None, + None, + ), + ); + } + + fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead { + decoder.decode(buf).unwrap().unwrap() + } + + #[ntex_rt::test] + async fn test_req_parse_err() { + let (client, server) = Io::create(); + client.remote_buffer_cap(1024); + client.write("GET /test HTTP/1\r\n\r\n"); + + let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); + delay_for(time::Duration::from_millis(50)).await; + + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); + assert!(!h1.inner.state.is_open()); + delay_for(time::Duration::from_millis(50)).await; + + client + .local_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n")); + + client.close().await; + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); + // assert!(h1.inner.flags.contains(Flags::SHUTDOWN_IO)); + assert!(h1.inner.state.is_io_err()); + } + + #[ntex_rt::test] + async fn test_pipeline() { + let (client, server) = Io::create(); + client.remote_buffer_cap(4096); + let mut decoder = ClientCodec::default(); + spawn_h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); + + client.write("GET /test HTTP/1.1\r\n\r\n"); + + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(!client.is_server_dropped()); + + client.write("GET /test HTTP/1.1\r\n\r\n"); + client.write("GET /test HTTP/1.1\r\n\r\n"); + + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(decoder.decode(&mut buf).unwrap().is_none()); + assert!(!client.is_server_dropped()); + + client.close().await; + assert!(client.is_server_dropped()); + } + + #[ntex_rt::test] + async fn test_pipeline_with_payload() { + let (client, server) = Io::create(); + client.remote_buffer_cap(4096); + let mut decoder = ClientCodec::default(); + spawn_h1(server, |mut req: Request| async move { + let mut p = req.take_payload(); + while let Some(_) = p.next().await {} + Ok::<_, io::Error>(Response::Ok().finish()) + }); + + client.write("GET /test HTTP/1.1\r\ncontent-length: 5\r\n\r\n"); + delay_for(time::Duration::from_millis(50)).await; + client.write("xxxxx"); + + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(!client.is_server_dropped()); + + client.write("GET /test HTTP/1.1\r\n\r\n"); + + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(decoder.decode(&mut buf).unwrap().is_none()); + assert!(!client.is_server_dropped()); + + client.close().await; + assert!(client.is_server_dropped()); + } + + #[ntex_rt::test] + async fn test_pipeline_with_delay() { + let (client, server) = Io::create(); + client.remote_buffer_cap(4096); + let mut decoder = ClientCodec::default(); + spawn_h1(server, |_| async { + delay_for(time::Duration::from_millis(100)).await; + Ok::<_, io::Error>(Response::Ok().finish()) + }); + + client.write("GET /test HTTP/1.1\r\n\r\n"); + + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(!client.is_server_dropped()); + + client.write("GET /test HTTP/1.1\r\n\r\n"); + client.write("GET /test HTTP/1.1\r\n\r\n"); + delay_for(time::Duration::from_millis(50)).await; + client.write("GET /test HTTP/1.1\r\n\r\n"); + + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(decoder.decode(&mut buf).unwrap().is_none()); + assert!(!client.is_server_dropped()); + + buf.extend(client.read().await.unwrap()); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(decoder.decode(&mut buf).unwrap().is_none()); + assert!(!client.is_server_dropped()); + + client.close().await; + assert!(client.is_server_dropped()); + } + + #[ntex_rt::test] + /// if socket is disconnected, h1 dispatcher does not process any data + // /// h1 dispatcher still processes all incoming requests + // /// but it does not write any data to socket + async fn test_write_disconnected() { + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + let (client, server) = Io::create(); + spawn_h1(server, move |_| { + num2.fetch_add(1, Ordering::Relaxed); + ok::<_, io::Error>(Response::Ok().finish()) + }); + + client.remote_buffer_cap(1024); + client.write("GET /test HTTP/1.1\r\n\r\n"); + client.write("GET /test HTTP/1.1\r\n\r\n"); + client.write("GET /test HTTP/1.1\r\n\r\n"); + client.close().await; + assert!(client.is_server_dropped()); + assert!(client.read_any().is_empty()); + + // only first request get handled + assert_eq!(num.load(Ordering::Relaxed), 0); + } + + #[ntex_rt::test] + async fn test_read_large_message() { + let (client, server) = Io::create(); + client.remote_buffer_cap(4096); + + let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); + let mut decoder = ClientCodec::default(); + + let data = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(70_000) + .map(char::from) + .collect::(); + client.write("GET /test HTTP/1.1\r\nContent-Length: "); + client.write(data); + delay_for(time::Duration::from_millis(50)).await; + + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); + delay_for(time::Duration::from_millis(50)).await; + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); + assert!(!h1.inner.state.is_open()); + + let mut buf = client.read().await.unwrap(); + assert_eq!(load(&mut decoder, &mut buf).status, StatusCode::BAD_REQUEST); + } + + #[ntex_rt::test] + async fn test_read_backpressure() { + let mark = Arc::new(AtomicBool::new(false)); + let mark2 = mark.clone(); + + let (client, server) = Io::create(); + client.remote_buffer_cap(4096); + spawn_h1(server, move |mut req: Request| { + let m = mark2.clone(); + async move { + // read one chunk + let mut pl = req.take_payload(); + let _ = pl.next().await.unwrap().unwrap(); + m.store(true, Ordering::Relaxed); + // sleep + delay_for(time::Duration::from_secs(999_999)).await; + Ok::<_, io::Error>(Response::Ok().finish()) + } + }); + + client.write("GET /test HTTP/1.1\r\nContent-Length: 1048576\r\n\r\n"); + delay_for(time::Duration::from_millis(50)).await; + + // buf must be consumed + assert_eq!(client.remote_buffer(|buf| buf.len()), 0); + + // io should be drained only by no more than MAX_BUFFER_SIZE + let random_bytes: Vec = + (0..1_048_576).map(|_| rand::random::()).collect(); + client.write(random_bytes); + + delay_for(time::Duration::from_millis(50)).await; + assert!(client.remote_buffer(|buf| buf.len()) > 1_048_576 - BUFFER_SIZE * 3); + assert!(mark.load(Ordering::Relaxed)); + } + + #[ntex_rt::test] + async fn test_write_backpressure() { + std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace"); + env_logger::init(); + + let num = Arc::new(AtomicUsize::new(0)); + let num2 = num.clone(); + + struct Stream(Arc); + + impl body::MessageBody for Stream { + fn size(&self) -> body::BodySize { + body::BodySize::Stream + } + fn poll_next_chunk( + &mut self, + _: &mut Context<'_>, + ) -> Poll>>> { + let data = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(65_536) + .map(char::from) + .collect::(); + self.0.fetch_add(data.len(), Ordering::Relaxed); + + Poll::Ready(Some(Ok(Bytes::from(data)))) + } + } + + let (client, server) = Io::create(); + let mut h1 = h1(server, move |_| { + let n = num2.clone(); + async move { Ok::<_, io::Error>(Response::Ok().message_body(Stream(n.clone()))) } + .boxed_local() + }); + let state = h1.inner.state.clone(); + + // do not allow to write to socket + client.remote_buffer_cap(0); + client.write("GET /test HTTP/1.1\r\n\r\n"); + delay_for(time::Duration::from_millis(50)).await; + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); + + // buf must be consumed + assert_eq!(client.remote_buffer(|buf| buf.len()), 0); + + // amount of generated data + assert_eq!(num.load(Ordering::Relaxed), 65_536); + + // response message + chunking encoding + assert_eq!(state.with_write_buf(|buf| buf.len()), 65629); + + client.remote_buffer_cap(65536); + delay_for(time::Duration::from_millis(50)).await; + assert_eq!(state.with_write_buf(|buf| buf.len()), 93); + + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); + assert_eq!(num.load(Ordering::Relaxed), 65_536 * 2); + } + + #[ntex_rt::test] + async fn test_disconnect_during_response_body_pending() { + struct Stream(bool); + + impl body::MessageBody for Stream { + fn size(&self) -> body::BodySize { + body::BodySize::Sized(2048) + } + fn poll_next_chunk( + &mut self, + _: &mut Context<'_>, + ) -> Poll>>> { + if self.0 { + Poll::Pending + } else { + self.0 = true; + let data = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(1024) + .map(char::from) + .collect::(); + Poll::Ready(Some(Ok(Bytes::from(data)))) + } + } + } + + let (client, server) = Io::create(); + client.remote_buffer_cap(4096); + let mut h1 = h1(server, |_| { + ok::<_, io::Error>(Response::Ok().message_body(Stream(false))) + }); + + client.write("GET /test HTTP/1.1\r\n\r\n"); + delay_for(time::Duration::from_millis(50)).await; + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); + + // http message must be consumed + assert_eq!(client.remote_buffer(|buf| buf.len()), 0); + + let mut decoder = ClientCodec::default(); + let mut buf = client.read().await.unwrap(); + assert!(load(&mut decoder, &mut buf).status.is_success()); + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); + + client.close().await; + delay_for(time::Duration::from_millis(50)).await; + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); + } +} diff --git a/ntex/src/http/h1/service.rs b/ntex/src/http/h1/service.rs index bbcd701e..0be924b6 100644 --- a/ntex/src/http/h1/service.rs +++ b/ntex/src/http/h1/service.rs @@ -68,7 +68,11 @@ where X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, - U: ServiceFactory, + U: ServiceFactory< + Config = (), + Request = (Request, TcpStream, IoState, Codec), + Response = (), + >, U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, @@ -112,7 +116,7 @@ mod openssl { X::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, IoState, Codec), + Request = (Request, SslStream, IoState, Codec), Response = (), >, U::Error: fmt::Display + Error + 'static, @@ -166,7 +170,7 @@ mod rustls { X::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, IoState, Codec), + Request = (Request, TlsStream, IoState, Codec), Response = (), >, U::Error: fmt::Display + Error + 'static, @@ -228,7 +232,7 @@ where pub fn upgrade(self, upgrade: Option) -> H1Service where - U1: ServiceFactory, + U1: ServiceFactory, U1::Error: fmt::Display + Error + 'static, U1::InitError: fmt::Debug, U1::Future: 'static, @@ -267,7 +271,11 @@ where X::Error: ResponseError + 'static, X::InitError: fmt::Debug, X::Future: 'static, - U: ServiceFactory, + U: ServiceFactory< + Config = (), + Request = (Request, T, IoState, Codec), + Response = (), + >, U::Error: fmt::Display + Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, @@ -331,13 +339,13 @@ where B: MessageBody, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, U::Error: fmt::Display + Error + 'static, { type Request = (T, Option); type Response = (); type Error = DispatchError; - type Future = Dispatcher; + type Future = Dispatcher; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let cfg = self.config.as_ref(); diff --git a/ntex/src/http/h1/upgrade.rs b/ntex/src/http/h1/upgrade.rs index 556b2546..9c4984b0 100644 --- a/ntex/src/http/h1/upgrade.rs +++ b/ntex/src/http/h1/upgrade.rs @@ -11,7 +11,7 @@ pub struct UpgradeHandler(PhantomData); impl ServiceFactory for UpgradeHandler { type Config = (); - type Request = (Request, State, Codec); + type Request = (Request, T, State, Codec); type Response = (); type Error = io::Error; type Service = UpgradeHandler; @@ -25,7 +25,7 @@ impl ServiceFactory for UpgradeHandler { } impl Service for UpgradeHandler { - type Request = (Request, State, Codec); + type Request = (Request, T, State, Codec); type Response = (); type Error = io::Error; type Future = Ready>; diff --git a/ntex/src/http/service.rs b/ntex/src/http/service.rs index 2886290c..08899f5e 100644 --- a/ntex/src/http/service.rs +++ b/ntex/src/http/service.rs @@ -126,7 +126,7 @@ where where U1: ServiceFactory< Config = (), - Request = (Request, State, h1::Codec), + Request = (Request, T, State, h1::Codec), Response = (), >, U1::Error: fmt::Display + error::Error + 'static, @@ -167,7 +167,11 @@ where X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory, + U: ServiceFactory< + Config = (), + Request = (Request, TcpStream, State, h1::Codec), + Response = (), + >, U::Error: fmt::Display + error::Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, @@ -213,7 +217,7 @@ mod openssl { ::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, State, h1::Codec), + Request = (Request, SslStream, State, h1::Codec), Response = (), >, U::Error: fmt::Display + error::Error + 'static, @@ -278,7 +282,7 @@ mod rustls { ::Future: 'static, U: ServiceFactory< Config = (), - Request = (Request, State, h1::Codec), + Request = (Request, TlsStream, State, h1::Codec), Response = (), >, U::Error: fmt::Display + error::Error + 'static, @@ -342,7 +346,11 @@ where X::InitError: fmt::Debug, X::Future: 'static, ::Future: 'static, - U: ServiceFactory, + U: ServiceFactory< + Config = (), + Request = (Request, T, State, h1::Codec), + Response = (), + >, U::Error: fmt::Display + error::Error + 'static, U::InitError: fmt::Debug, U::Future: 'static, @@ -410,7 +418,7 @@ where B: MessageBody + 'static, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, U::Error: fmt::Display + error::Error + 'static, { type Request = (T, Protocol, Option); @@ -513,6 +521,7 @@ pin_project_lite::pin_project! { T: AsyncRead, T: AsyncWrite, T: Unpin, + T: 'static, S: Service, S::Error: ResponseError, S::Error: 'static, @@ -523,7 +532,7 @@ pin_project_lite::pin_project! { X: Service, X::Error: ResponseError, X::Error: 'static, - U: Service, + U: Service, U::Error: fmt::Display, U::Error: error::Error, U::Error: 'static, @@ -543,16 +552,17 @@ pin_project_lite::pin_project! { T: AsyncRead, T: AsyncWrite, T: Unpin, + T: 'static, B: MessageBody, X: Service, X::Error: ResponseError, X::Error: 'static, - U: Service, + U: Service, U::Error: fmt::Display, U::Error: error::Error, U::Error: 'static, { - H1 { #[pin] fut: h1::Dispatcher }, + H1 { #[pin] fut: h1::Dispatcher }, H2 { fut: Dispatcher }, H2Handshake { data: Option<( @@ -567,7 +577,7 @@ pin_project_lite::pin_project! { impl Future for HttpServiceHandlerResponse where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin + 'static, S: Service, S::Error: ResponseError + 'static, S::Future: 'static, @@ -575,7 +585,7 @@ where B: MessageBody, X: Service, X::Error: ResponseError + 'static, - U: Service, + U: Service, U::Error: fmt::Display + error::Error + 'static, { type Output = Result<(), DispatchError>; diff --git a/ntex/tests/http_awc_ws.rs b/ntex/tests/http_awc_ws.rs index c8ad2b42..380c6a42 100644 --- a/ntex/tests/http_awc_ws.rs +++ b/ntex/tests/http_awc_ws.rs @@ -8,12 +8,12 @@ use ntex::framed::{DispatchItem, Dispatcher, State}; use ntex::http::test::server as test_server; use ntex::http::ws::handshake_response; use ntex::http::{body::BodySize, h1, HttpService, Request, Response}; +use ntex::rt::net::TcpStream; use ntex::ws; async fn ws_service( msg: DispatchItem, ) -> Result, io::Error> { - println!("TEST: {:?}", msg); let msg = match msg { DispatchItem::Item(msg) => match msg { ws::Frame::Ping(msg) => ws::Message::Pong(msg), @@ -31,33 +31,33 @@ async fn ws_service( #[ntex::test] async fn test_simple() { - std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace"); - env_logger::init(); - let mut srv = test_server(|| { HttpService::build() - .upgrade(|(req, state, mut codec): (Request, State, h1::Codec)| { - async move { - let res = handshake_response(req.head()).finish(); + .upgrade( + |(req, io, state, mut codec): (Request, TcpStream, State, h1::Codec)| { + async move { + let res = handshake_response(req.head()).finish(); - // send handshake respone - state - .write_item( - h1::Message::Item((res.drop_body(), BodySize::None)), - &mut codec, + // send handshake respone + state + .write_item( + h1::Message::Item((res.drop_body(), BodySize::None)), + &mut codec, + ) + .unwrap(); + + // start websocket service + Dispatcher::new( + io, + ws::Codec::default(), + state, + ws_service, + Default::default(), ) - .unwrap(); - - // start websocket service - Dispatcher::from_state( - ws::Codec::default(), - state, - ws_service, - Default::default(), - ) - .await - } - }) + .await + } + }, + ) .finish(|_| ok::<_, io::Error>(Response::NotFound())) .tcp() }); diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index a457ce9a..be15d911 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -85,6 +85,9 @@ async fn test_expect_continue() { #[ntex::test] async fn test_expect_continue_h1() { + std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace"); + env_logger::init(); + let srv = test_server(|| { HttpService::build() .expect(fn_service(|req: Request| { @@ -115,7 +118,9 @@ async fn test_expect_continue_h1() { let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n"); let mut data = String::new(); + println!("1-------------------"); let _ = stream.read_to_string(&mut data); + println!("2-------------------"); assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n")); } diff --git a/ntex/tests/http_ws.rs b/ntex/tests/http_ws.rs index 94b396d3..be4cebe4 100644 --- a/ntex/tests/http_ws.rs +++ b/ntex/tests/http_ws.rs @@ -1,20 +1,21 @@ use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use std::{cell::Cell, io, pin::Pin}; +use std::{cell::Cell, io, marker::PhantomData, pin::Pin}; use bytes::Bytes; use futures::{future, Future, SinkExt, StreamExt}; +use ntex::codec::{AsyncRead, AsyncWrite}; use ntex::framed::{DispatchItem, Dispatcher, State, Timer}; use ntex::http::{body, h1, test, ws::handshake, HttpService, Request, Response}; use ntex::service::{fn_factory, Service}; use ntex::ws; -struct WsService(Arc>>); +struct WsService(Arc>>, PhantomData); -impl WsService { +impl WsService { fn new() -> Self { - WsService(Arc::new(Mutex::new(Cell::new(false)))) + WsService(Arc::new(Mutex::new(Cell::new(false))), PhantomData) } fn set_polled(&self) { @@ -26,14 +27,17 @@ impl WsService { } } -impl Clone for WsService { +impl Clone for WsService { fn clone(&self) -> Self { - WsService(self.0.clone()) + WsService(self.0.clone(), PhantomData) } } -impl Service for WsService { - type Request = (Request, State, h1::Codec); +impl Service for WsService +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Request = (Request, T, State, h1::Codec); type Response = (); type Error = io::Error; type Future = Pin>>>; @@ -43,7 +47,7 @@ impl Service for WsService { Poll::Ready(Ok(())) } - fn call(&self, (req, state, mut codec): Self::Request) -> Self::Future { + fn call(&self, (req, io, state, mut codec): Self::Request) -> Self::Future { let fut = async move { let res = handshake(req.head()).unwrap().message_body(()); @@ -51,7 +55,7 @@ impl Service for WsService { .write_item((res, body::BodySize::None).into(), &mut codec) .unwrap(); - Dispatcher::from_state(ws::Codec::new(), state, service, Timer::default()) + Dispatcher::new(io, ws::Codec::new(), state, service, Timer::default()) .await .map_err(|_| panic!()) };