diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index 24c02d56..91376c91 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.1.0] - 2024-07-30 + +* Better handling for connection upgrade #385 + ## [2.0.3] - 2024-06-27 * Re-export server signals api diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index a02d8a62..b3d4cb93 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "2.0.3" +version = "2.1.0" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -68,7 +68,7 @@ ntex-bytes = "0.1.27" ntex-server = "2.1" ntex-h2 = "1.0" ntex-rt = "0.4.13" -ntex-io = "2.0" +ntex-io = "2.1" ntex-net = "2.0" ntex-tls = "2.0" diff --git a/ntex/examples/basic.rs b/ntex/examples/basic.rs index d3e84fca..e4599b2a 100644 --- a/ntex/examples/basic.rs +++ b/ntex/examples/basic.rs @@ -1,44 +1,29 @@ -use ntex::http; -use ntex::web::{self, middleware, App, HttpRequest, HttpResponse, HttpServer}; +use ntex::web; -#[web::get("/resource1/{name}/index.html")] -async fn index(req: HttpRequest, name: web::types::Path) -> String { - println!("REQ: {:?}", req); - format!("Hello: {}!\r\n", name) +#[derive(serde::Deserialize)] +struct Info { + username: String, } -async fn index_async(req: HttpRequest) -> &'static str { - println!("REQ: {:?}", req); - "Hello world!\r\n" -} - -#[web::get("/")] -async fn no_params() -> &'static str { - "Hello world!\r\n" +async fn submit(info: web::types::Json) -> Result { + Ok(format!("Welcome {}!", info.username)) } #[ntex::main] async fn main() -> std::io::Result<()> { - std::env::set_var("RUST_LOG", "ntex=trace"); + std::env::set_var("RUST_LOG", "trace"); env_logger::init(); + web::HttpServer::new(|| { + let json_config = web::types::JsonConfig::default().limit(4096); - HttpServer::new(|| { - App::new() - .wrap(middleware::Logger::default()) - .service((index, no_params)) - .service( - web::resource("/resource2/index.html") - .wrap(middleware::DefaultHeaders::new().header("X-Version-R2", "0.3")) - .default_service( - web::route().to(|| async { HttpResponse::MethodNotAllowed() }), - ) - .route(web::get().to(index_async)), - ) - .service(web::resource("/test1.html").to(|| async { "Test\r\n" })) + web::App::new().service( + web::resource("/") + .state(json_config) + .route(web::post().to(submit)), + ) }) - .bind("0.0.0.0:8081")? - .workers(4) - .keep_alive(http::KeepAlive::Disabled) + .bind(("127.0.0.1", 8080))? + .workers(1) .run() .await } diff --git a/ntex/src/http/h1/codec.rs b/ntex/src/http/h1/codec.rs index 8a9e431f..eba02f04 100644 --- a/ntex/src/http/h1/codec.rs +++ b/ntex/src/http/h1/codec.rs @@ -99,10 +99,6 @@ impl Codec { self.ctype.get() == ConnectionType::KeepAlive } - pub(super) fn set_ctype(&self, ctype: ConnectionType) { - self.ctype.set(ctype) - } - #[inline] #[doc(hidden)] pub fn set_date_header(&self, dst: &mut BytesMut) { @@ -115,10 +111,11 @@ impl Codec { self.flags.set(flags); } - pub(super) fn unset_streaming(&self) { + pub(super) fn reset_upgrade(&self) { let mut flags = self.flags.get(); flags.remove(Flags::STREAM); self.flags.set(flags); + self.ctype.set(ConnectionType::Close); } } diff --git a/ntex/src/http/h1/control.rs b/ntex/src/http/h1/control.rs index b724dcce..f8844aaf 100644 --- a/ntex/src/http/h1/control.rs +++ b/ntex/src/http/h1/control.rs @@ -1,8 +1,8 @@ -use std::{fmt, future::Future, io}; +use std::{fmt, future::Future, io, rc::Rc}; use crate::http::message::CurrentIo; use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError}; -use crate::io::{Filter, Io, IoBoxed}; +use crate::io::{Filter, Io, IoBoxed, IoRef}; pub enum Control { /// New request is loaded @@ -46,12 +46,8 @@ pub(super) enum ControlResult { Upgrade(Request), /// forward request to publish service Publish(Request), - /// forward request to publish service - PublishUpgrade(Request), /// send response Response(Response<()>, Body), - /// send response - ResponseWithIo(Response<()>, Body, IoBoxed), /// drop connection Stop, } @@ -72,7 +68,7 @@ impl Control { Control::NewRequest(NewRequest(req)) } - pub(super) fn upgrade(req: Request, io: Io, codec: Codec) -> Self { + pub(super) fn upgrade(req: Request, io: Rc>, codec: Codec) -> Self { Control::Upgrade(Upgrade { req, io, codec }) } @@ -188,10 +184,34 @@ impl NewRequest { pub struct Upgrade { req: Request, - io: Io, + io: Rc>, codec: Codec, } +struct RequestIoAccess { + io: Rc>, + codec: Codec, +} + +impl fmt::Debug for RequestIoAccess { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RequestIoAccess") + .field("io", self.io.as_ref()) + .field("codec", &self.codec) + .finish() + } +} + +impl crate::http::message::IoAccess for RequestIoAccess { + fn get(&self) -> Option<&IoRef> { + Some(self.io.as_ref()) + } + + fn take(&self) -> Option<(IoBoxed, Codec)> { + Some((self.io.take().into(), self.codec.clone())) + } +} + impl Upgrade { #[inline] /// Returns reference to Io @@ -215,12 +235,14 @@ impl Upgrade { /// Ack upgrade request and continue handling process pub fn ack(mut self) -> ControlAck { // Move io into request - let io: IoBoxed = self.io.into(); - io.stop_timer(); - self.req.head_mut().io = CurrentIo::new(io, self.codec); + let io = Rc::new(RequestIoAccess { + io: self.io, + codec: self.codec, + }); + self.req.head_mut().io = CurrentIo::new(io); ControlAck { - result: ControlResult::PublishUpgrade(self.req), + result: ControlResult::Publish(self.req), flags: ControlFlags::DISCONNECT, } } @@ -232,8 +254,9 @@ impl Upgrade { H: FnOnce(Request, Io, Codec) -> R + 'static, R: Future, { + let io = self.io.take(); let _ = crate::rt::spawn(async move { - let _ = f(self.req, self.io, self.codec).await; + let _ = f(self.req, io, self.codec).await; }); ControlAck { result: ControlResult::Stop, @@ -248,7 +271,7 @@ impl Upgrade { let (res, body) = res.into_parts(); ControlAck { - result: ControlResult::ResponseWithIo(res, body.into(), self.io.into()), + result: ControlResult::Response(res, body.into()), flags: ControlFlags::DISCONNECT, } } @@ -259,7 +282,7 @@ impl Upgrade { let (res, body) = res.into_parts(); ControlAck { - result: ControlResult::ResponseWithIo(res, body.into(), self.io.into()), + result: ControlResult::Response(res, body.into()), flags: ControlFlags::DISCONNECT, } } diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 664b0de8..e7e59cdf 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,14 +1,14 @@ //! HTTP/1 protocol dispatcher use std::{error, future, io, marker, pin::Pin, rc::Rc, task::Context, task::Poll}; -use crate::io::{Decoded, Filter, Io, IoBoxed, IoStatusUpdate, RecvError}; +use crate::io::{Decoded, Filter, Io, IoStatusUpdate, RecvError}; use crate::service::{PipelineCall, Service}; use crate::time::Seconds; use crate::util::{ready, Either}; use crate::http::body::{BodySize, MessageBody, ResponseBody}; use crate::http::error::{PayloadError, ResponseError}; -use crate::http::message::{ConnectionType, CurrentIo}; +use crate::http::message::CurrentIo; use crate::http::{self, config::DispatcherConfig, request::Request, response::Response}; use super::control::{Control, ControlAck, ControlFlags, ControlResult}; @@ -19,8 +19,6 @@ use super::{codec::Codec, Message, ProtocolError}; bitflags::bitflags! { #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct Flags: u8 { - /// Upgrade hnd - const UPGRADE = 0b0000_0001; /// Stopping const SENDPAYLOAD_AND_STOP = 0b0000_0010; /// Complete operation and disconnect @@ -65,18 +63,13 @@ where SendPayload { body: ResponseBody, }, - SendPayloadAndStop { - body: ResponseBody, - io: IoBoxed, - }, Stop { fut: Option>>, - io: Option, }, } struct DispatcherInner { - io: Io, + io: Rc>, flags: Flags, codec: Codec, config: Rc>, @@ -112,10 +105,10 @@ where Dispatcher { st: State::ReadRequest, inner: DispatcherInner { - io, flags, codec, config, + io: Rc::new(io), payload: None, read_remains: 0, read_consumed: 0, @@ -148,20 +141,10 @@ where State::CallPublish { fut } => match Pin::new(fut).poll(cx) { Poll::Ready(Ok(res)) => { let (res, body) = res.into().into_parts(); - if inner.flags.contains(Flags::UPGRADE) { - inner.send_response_to(res, body, None) - } else { - inner.send_response(res, body) - } + inner.send_response(res, body) } Poll::Ready(Err(err)) => inner.control(Control::err(err)), - Poll::Pending => { - if !inner.flags.contains(Flags::UPGRADE) { - ready!(inner.poll_request(cx)) - } else { - return Poll::Pending; - } - } + Poll::Pending => ready!(inner.poll_request(cx)), }, // handle control service responses State::CallControl { fut } => match Pin::new(fut).poll(cx) { @@ -181,16 +164,9 @@ where match result { ControlResult::Publish(req) => inner.publish(req), - ControlResult::PublishUpgrade(req) => { - inner.flags.insert(Flags::UPGRADE); - inner.publish(req) - } ControlResult::Response(res, body) => { inner.send_response(res, body.into()) } - ControlResult::ResponseWithIo(res, body, io) => { - inner.send_response_to(res, body.into(), Some(io)) - } ControlResult::Expect(req) => { inner.control(Control::expect(req)) } @@ -219,27 +195,17 @@ where State::SendPayload { body } => { ready!(inner.poll_send_payload(cx, body)) } - // send response body - State::SendPayloadAndStop { body, io } => { - ready!(inner.poll_send_payload_to(cx, body, io)) - } // shutdown io - State::Stop { fut, io } => { + State::Stop { fut } => { if let Some(ref mut f) = fut { let _ = ready!(Pin::new(f).poll(cx)); fut.take(); } log::debug!("{}: Dispatcher is stopped", inner.io.tag()); + inner.io.stop_timer(); return Poll::Ready( - if let Some(io) = io { - io.stop_timer(); - ready!(io.poll_shutdown(cx)) - } else { - inner.io.stop_timer(); - ready!(inner.io.poll_shutdown(cx)) - } - .map_err(From::from), + ready!(inner.io.poll_shutdown(cx)).map_err(From::from), ); } } @@ -438,76 +404,6 @@ where } } - fn send_response_to( - &mut self, - res: Response<()>, - body: ResponseBody, - io: Option, - ) -> State { - let io = if let Some(io) = io { - io - } else if let Some((io, codec)) = res.head().io.take() { - self.codec = codec; - io - } else { - log::trace!("Handler service consumed io, stop"); - return self.stop(); - }; - - self.codec.set_ctype(ConnectionType::Close); - self.codec.unset_streaming(); - - if io - .encode(Message::Item((res, body.size())), &self.codec) - .is_ok() - { - match body.size() { - BodySize::None | BodySize::Empty => self.stop_io(io), - _ => State::SendPayloadAndStop { io, body }, - } - } else { - self.stop_io(io) - } - } - - /// send response body to specified io - fn poll_send_payload_to( - &mut self, - cx: &mut Context<'_>, - body: &mut ResponseBody, - io: &mut IoBoxed, - ) -> Poll> { - if io.is_closed() { - return Poll::Ready(self.stop()); - } else if !self.flags.contains(Flags::SENDPAYLOAD_AND_STOP) { - if let Poll::Ready(Err(_)) = self._poll_request_payload(Some(io), cx) { - self.flags.insert(Flags::SENDPAYLOAD_AND_STOP); - } - } - - loop { - let _ = ready!(io.poll_flush(cx, false)); - match ready!(body.poll_next_chunk(cx)) { - Some(Ok(item)) => { - if let Err(e) = io.encode(Message::Chunk(Some(item)), &self.codec) { - log::trace!("{}: Cannot encode chunk: {:?}", io.tag(), e); - } else { - continue; - } - } - None => { - if let Err(e) = io.encode(Message::Chunk(None), &self.codec) { - log::trace!("{}: Cannot encode payload eof: {:?}", io.tag(), e); - } - } - Some(Err(e)) => { - log::trace!("{}: error during response body poll: {:?}", io.tag(), e); - } - } - return Poll::Ready(self.stop_io(io.take())); - } - } - /// we might need to read more data into a request payload /// (ie service future can wait for payload data) fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -823,20 +719,13 @@ where } fn ctl_upgrade(&mut self, req: Request) -> State { - let msg = Control::upgrade(req, self.io.take(), self.codec.clone()); + self.codec.reset_upgrade(); + let msg = Control::upgrade(req, self.io.clone(), self.codec.clone()); self.control(msg) } fn stop(&mut self) -> State { State::Stop { - io: None, - fut: Some(self.config.control.call_nowait(Control::closed())), - } - } - - fn stop_io(&mut self, io: IoBoxed) -> State { - State::Stop { - io: Some(io), fut: Some(self.config.control.call_nowait(Control::closed())), } } diff --git a/ntex/src/http/message.rs b/ntex/src/http/message.rs index db5adc41..2d99b315 100644 --- a/ntex/src/http/message.rs +++ b/ntex/src/http/message.rs @@ -39,26 +39,32 @@ pub(crate) trait Head: Default + 'static + fmt::Debug { #[derive(Clone, Debug)] pub(crate) enum CurrentIo { Ref(IoRef), - Io(Rc<(IoRef, RefCell>)>), + Io(Rc), None, } +pub(crate) trait IoAccess: fmt::Debug { + fn get(&self) -> Option<&IoRef>; + + fn take(&self) -> Option<(IoBoxed, Codec)>; +} + impl CurrentIo { - pub(crate) fn new(io: IoBoxed, codec: Codec) -> Self { - CurrentIo::Io(Rc::new((io.get_ref(), RefCell::new(Some((io, codec)))))) + pub(crate) fn new(io: Rc) -> Self { + CurrentIo::Io(io) } pub(crate) fn as_ref(&self) -> Option<&IoRef> { match self { CurrentIo::Ref(ref io) => Some(io), - CurrentIo::Io(ref io) => Some(&io.0), + CurrentIo::Io(ref io) => io.get(), CurrentIo::None => None, } } pub(crate) fn take(&self) -> Option<(IoBoxed, Codec)> { match self { - CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(), + CurrentIo::Io(ref io) => io.take(), _ => None, } } @@ -215,16 +221,6 @@ impl RequestHead { pub fn remove_io(&mut self) { self.io = CurrentIo::None; } - - pub(crate) fn take_io_rc( - &self, - ) -> Option>)>> { - if let CurrentIo::Io(ref r) = self.io { - Some(r.clone()) - } else { - None - } - } } #[derive(Debug)] @@ -384,10 +380,6 @@ impl ResponseHead { self.flags.remove(Flags::NO_CHUNKING); } } - - pub(crate) fn set_io(&mut self, io: Rc<(IoRef, RefCell>)>) { - self.io = CurrentIo::Io(io) - } } impl Default for ResponseHead { diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index dab5840d..4ee578b0 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -239,6 +239,7 @@ where sys.run(move || { crate::server::build() .listen("test", tcp, move |_| factory())? + .set_tag("test", "HTTP-TEST-SRV") .workers(1) .disable_signals() .run(); diff --git a/ntex/src/web/response.rs b/ntex/src/web/response.rs index 17f10fb2..abd0d01b 100644 --- a/ntex/src/web/response.rs +++ b/ntex/src/web/response.rs @@ -124,10 +124,7 @@ impl WebResponse { } impl From for Response { - fn from(mut res: WebResponse) -> Response { - if let Some(io) = res.request.head().take_io_rc() { - res.response.head_mut().set_io(io); - } + fn from(res: WebResponse) -> Response { res.response } } diff --git a/ntex/tests/web_server.rs b/ntex/tests/web_server.rs index bc405288..aefe6682 100644 --- a/ntex/tests/web_server.rs +++ b/ntex/tests/web_server.rs @@ -1141,8 +1141,9 @@ async fn test_web_server() { system.stop(); } +/// Websocket connection, no ws handler and response contains payload #[ntex::test] -async fn web_no_ws_payload() { +async fn web_no_ws_with_response_payload() { let srv = test::server_with(test::config().h1(), || { App::new() .service(web::resource("/").route(web::get().to(move || async {