diff --git a/ntex-util/src/future/either.rs b/ntex-util/src/future/either.rs index af59bfa7..2c142dfb 100644 --- a/ntex-util/src/future/either.rs +++ b/ntex-util/src/future/either.rs @@ -89,6 +89,12 @@ where A: error::Error, B: error::Error, { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + match self { + Either::Left(a) => a.source(), + Either::Right(b) => b.source(), + } + } } impl fmt::Display for Either diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index 2961ebfa..642e4202 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [1.1.0] - 2024-02-07 + +* http: Add http/1 control service + +* http: Add http/2 control service + ## [1.0.0] - 2024-01-09 * web: Use async fn for Responder and Handler traits diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 90799fae..1ab1107d 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "1.0.0" +version = "1.1.0" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -52,13 +52,13 @@ ntex-codec = "0.6.2" ntex-connect = "1.0.0" ntex-http = "0.1.12" ntex-router = "0.5.3" -ntex-service = "2.0.0" +ntex-service = "2.0.1" ntex-macros = "0.1.3" -ntex-util = "1.0.0" +ntex-util = "1.0.1" ntex-bytes = "0.1.24" ntex-h2 = "0.5.0" ntex-rt = "0.4.11" -ntex-io = "1.0.0" +ntex-io = "1.0.1" ntex-tls = "1.0.0" ntex-tokio = { version = "0.4.0", optional = true } ntex-glommio = { version = "0.4.0", optional = true } diff --git a/ntex/src/http/body.rs b/ntex/src/http/body.rs index 02d47234..048854e7 100644 --- a/ntex/src/http/body.rs +++ b/ntex/src/http/body.rs @@ -70,6 +70,15 @@ impl ResponseBody { } } +impl From> for Body { + fn from(b: ResponseBody) -> Self { + match b { + ResponseBody::Body(b) => b, + ResponseBody::Other(b) => b, + } + } +} + impl From for ResponseBody { fn from(b: Body) -> Self { ResponseBody::Other(b) diff --git a/ntex/src/http/builder.rs b/ntex/src/http/builder.rs index 859cdba5..14839efb 100644 --- a/ntex/src/http/builder.rs +++ b/ntex/src/http/builder.rs @@ -1,32 +1,31 @@ use std::{error::Error, fmt, marker::PhantomData}; -use ntex_h2::{self as h2}; - use crate::http::body::MessageBody; -use crate::http::config::{KeepAlive, OnRequest, ServiceConfig}; -use crate::http::error::ResponseError; -use crate::http::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler}; -use crate::http::h2::H2Service; -use crate::http::request::Request; -use crate::http::response::Response; -use crate::http::service::HttpService; -use crate::io::{Filter, Io, IoRef}; -use crate::service::{boxed, IntoService, IntoServiceFactory, Service, ServiceFactory}; -use crate::time::Seconds; +use crate::http::config::{KeepAlive, ServiceConfig}; +use crate::http::error::{H2Error, ResponseError}; +use crate::http::h1::{self, H1Service}; +use crate::http::h2::{self, H2Service}; +use crate::http::{request::Request, response::Response, service::HttpService}; +use crate::service::{IntoServiceFactory, ServiceFactory}; +use crate::{io::Filter, time::Seconds}; /// A http service builder /// /// This type can be used to construct an instance of `http service` through a /// builder-like pattern. -pub struct HttpServiceBuilder> { +pub struct HttpServiceBuilder< + F, + S, + C1 = h1::DefaultControlService, + C2 = h2::DefaultControlService, +> { config: ServiceConfig, - expect: X, - upgrade: Option, - on_request: Option, + h1_control: C1, + h2_control: C2, _t: PhantomData<(F, S)>, } -impl HttpServiceBuilder> { +impl HttpServiceBuilder { /// Create instance of `ServiceConfigBuilder` pub fn new() -> Self { HttpServiceBuilder::with_config(ServiceConfig::default()) @@ -37,26 +36,25 @@ impl HttpServiceBuilder> { pub fn with_config(config: ServiceConfig) -> Self { HttpServiceBuilder { config, - expect: ExpectHandler, - upgrade: None, - on_request: None, + h1_control: h1::DefaultControlService, + h2_control: h2::DefaultControlService, _t: PhantomData, } } } -impl HttpServiceBuilder +impl HttpServiceBuilder where F: Filter, S: ServiceFactory + 'static, S::Error: ResponseError, S::InitError: fmt::Debug, - X: ServiceFactory + 'static, - X::Error: ResponseError, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io, Codec), Response = ()> + 'static, - U::Error: fmt::Display + Error, - U::InitError: fmt::Debug, + C1: ServiceFactory, Response = h1::ControlAck>, + C1::Error: Error, + C1::InitError: fmt::Debug, + C2: ServiceFactory, Response = h2::ControlResult>, + C2::Error: Error, + C2::InitError: fmt::Debug, { /// Set server keep-alive setting. /// @@ -138,70 +136,24 @@ where self } - #[doc(hidden)] - /// Configure http2 connection settings - pub fn configure_http2(self, f: O) -> Self + /// Provide control service for http/1. + pub fn h1_control(self, control: CF) -> HttpServiceBuilder where - O: FnOnce(&h2::Config) -> R, - { - let _ = f(&self.config.h2config); - self - } - - /// Provide service for `EXPECT: 100-Continue` support. - /// - /// Service get called with request that contains `EXPECT` header. - /// Service must return request in case of success, in that case - /// request will be forwarded to main service. - pub fn expect(self, expect: XF) -> HttpServiceBuilder - where - XF: IntoServiceFactory, - X1: ServiceFactory, - X1::InitError: fmt::Debug, + CF: IntoServiceFactory>, + CT: ServiceFactory, Response = h1::ControlAck>, + CT::Error: Error, + CT::InitError: fmt::Debug, { HttpServiceBuilder { config: self.config, - expect: expect.into_factory(), - upgrade: self.upgrade, - on_request: self.on_request, + h2_control: self.h2_control, + h1_control: control.into_factory(), _t: PhantomData, } } - /// Provide service for custom `Connection: UPGRADE` support. - /// - /// If service is provided then normal requests handling get halted - /// and this service get called with original request and framed object. - pub fn upgrade(self, upgrade: UF) -> HttpServiceBuilder - where - UF: IntoServiceFactory, Codec)>, - U1: ServiceFactory<(Request, Io, Codec), Response = ()>, - U1::Error: fmt::Display + Error, - U1::InitError: fmt::Debug, - { - HttpServiceBuilder { - config: self.config, - expect: self.expect, - upgrade: Some(upgrade.into_factory()), - on_request: self.on_request, - _t: PhantomData, - } - } - - /// Set req request callback. - /// - /// It get called once per request. - pub fn on_request(mut self, f: FR) -> Self - where - FR: IntoService, - R: Service<(Request, IoRef), Response = Request, Error = Response> + 'static, - { - self.on_request = Some(boxed::service(f.into_service())); - self - } - /// Finish service configuration and create *http service* for HTTP/1 protocol. - pub fn h1(self, service: SF) -> H1Service + pub fn h1(self, service: SF) -> H1Service where B: MessageBody, SF: IntoServiceFactory, @@ -209,14 +161,36 @@ where S::InitError: fmt::Debug, S::Response: Into>, { - H1Service::with_config(self.config, service.into_factory()) - .expect(self.expect) - .upgrade(self.upgrade) - .on_request(self.on_request) + H1Service::with_config(self.config, service.into_factory()).control(self.h1_control) + } + + /// Provide control service for http/2 protocol. + pub fn h2_control(self, control: CF) -> HttpServiceBuilder + where + CF: IntoServiceFactory>, + CT: ServiceFactory, Response = h2::ControlResult>, + CT::Error: Error, + CT::InitError: fmt::Debug, + { + HttpServiceBuilder { + config: self.config, + h1_control: self.h1_control, + h2_control: control.into_factory(), + _t: PhantomData, + } + } + + /// Configure http2 connection settings + pub fn h2_configure(self, f: O) -> Self + where + O: FnOnce(&h2::Config) -> R, + { + let _ = f(&self.config.h2config); + self } /// Finish service configuration and create *http service* for HTTP/2 protocol. - pub fn h2(self, service: SF) -> H2Service + pub fn h2(self, service: SF) -> H2Service where B: MessageBody + 'static, SF: IntoServiceFactory, @@ -224,11 +198,11 @@ where S::InitError: fmt::Debug, S::Response: Into> + 'static, { - H2Service::with_config(self.config, service.into_factory()) + H2Service::with_config(self.config, service.into_factory()).control(self.h2_control) } /// Finish service configuration and create `HttpService` instance. - pub fn finish(self, service: SF) -> HttpService + pub fn finish(self, service: SF) -> HttpService where B: MessageBody + 'static, SF: IntoServiceFactory, @@ -237,8 +211,7 @@ where S::Response: Into> + 'static, { HttpService::with_config(self.config, service.into_factory()) - .expect(self.expect) - .upgrade(self.upgrade) - .on_request(self.on_request) + .h1_control(self.h1_control) + .h2_control(self.h2_control) } } diff --git a/ntex/src/http/client/error.rs b/ntex/src/http/client/error.rs index c097a25e..e7ef3119 100644 --- a/ntex/src/http/client/error.rs +++ b/ntex/src/http/client/error.rs @@ -7,7 +7,7 @@ use thiserror::Error; #[cfg(feature = "openssl")] use crate::connect::openssl::{HandshakeError, SslError}; -use crate::http::error::{HttpError, ParseError, PayloadError}; +use crate::http::error::{DecodeError, EncodeError, HttpError, PayloadError}; use crate::util::Either; /// A set of errors that can occur during parsing json payloads @@ -142,9 +142,12 @@ pub enum SendRequestError { /// Error sending request #[error("Error sending request: {0}")] Send(#[from] io::Error), + /// Error encoding request + #[error("Error during request encoding: {0}")] + Request(#[from] EncodeError), /// Error parsing response #[error("Error during response parsing: {0}")] - Response(#[from] ParseError), + Response(#[from] DecodeError), /// Http error #[error("{0}")] Http(#[from] HttpError), @@ -162,17 +165,17 @@ pub enum SendRequestError { Error(#[from] Box), } -impl From> for SendRequestError { - fn from(err: Either) -> Self { +impl From> for SendRequestError { + fn from(err: Either) -> Self { match err { - Either::Left(err) => SendRequestError::Send(err), + Either::Left(err) => SendRequestError::Request(err), Either::Right(err) => SendRequestError::Send(err), } } } -impl From> for SendRequestError { - fn from(err: Either) -> Self { +impl From> for SendRequestError { + fn from(err: Either) -> Self { match err { Either::Left(err) => SendRequestError::Response(err), Either::Right(err) => SendRequestError::Send(err), diff --git a/ntex/src/http/config.rs b/ntex/src/http/config.rs index c5a3e65e..ce7d9dde 100644 --- a/ntex/src/http/config.rs +++ b/ntex/src/http/config.rs @@ -2,10 +2,8 @@ use std::{cell::Cell, ptr::copy_nonoverlapping, rc::Rc, time}; use ntex_h2::{self as h2}; -use crate::http::{Request, Response}; -use crate::service::{boxed::BoxService, Pipeline}; use crate::time::{sleep, Millis, Seconds}; -use crate::{io::IoRef, util::BytesMut}; +use crate::{service::Pipeline, util::BytesMut}; #[derive(Debug, PartialEq, Eq, Clone, Copy)] /// Server keep-alive setting @@ -236,12 +234,9 @@ impl ServiceConfig { } } -pub(super) type OnRequest = BoxService<(Request, IoRef), Request, Response>; - -pub(super) struct DispatcherConfig { +pub(super) struct DispatcherConfig { pub(super) service: Pipeline, - pub(super) expect: Pipeline, - pub(super) upgrade: Option>, + pub(super) control: Pipeline, pub(super) keep_alive: Seconds, pub(super) client_disconnect: Seconds, pub(super) h2config: h2::Config, @@ -249,22 +244,13 @@ pub(super) struct DispatcherConfig { pub(super) headers_read_rate: Option, pub(super) payload_read_rate: Option, pub(super) timer: DateService, - pub(super) on_request: Option>, } -impl DispatcherConfig { - pub(super) fn new( - cfg: ServiceConfig, - service: S, - expect: X, - upgrade: Option, - on_request: Option, - ) -> Self { +impl DispatcherConfig { + pub(super) fn new(cfg: ServiceConfig, service: S, control: C) -> Self { DispatcherConfig { service: service.into(), - expect: expect.into(), - upgrade: upgrade.map(|v| v.into()), - on_request: on_request.map(|v| v.into()), + control: control.into(), keep_alive: cfg.keep_alive, client_disconnect: cfg.client_disconnect, ka_enabled: cfg.ka_enabled, diff --git a/ntex/src/http/error.rs b/ntex/src/http/error.rs index f32fcdde..f13552bf 100644 --- a/ntex/src/http/error.rs +++ b/ntex/src/http/error.rs @@ -57,9 +57,24 @@ impl ResponseError for io::Error {} /// `InternalServerError` for `JsonError` impl ResponseError for serde_json::error::Error {} +/// A set of errors that can occur during HTTP streams encoding +#[derive(thiserror::Error, Debug)] +pub enum EncodeError { + /// An invalid `HttpVersion`, such as `HTP/1.1` + #[error("Unsupported HTTP version specified")] + UnsupportedVersion(super::Version), + + #[error("Unexpected end of bytes stream")] + UnexpectedEof, + + /// Internal error + #[error("Internal error")] + Internal(Box), +} + /// A set of errors that can occur during parsing HTTP streams #[derive(thiserror::Error, Debug)] -pub enum ParseError { +pub enum DecodeError { /// An invalid `Method`, such as `GE.T`. #[error("Invalid Method specified")] Method, @@ -74,17 +89,13 @@ pub enum ParseError { Header, /// A message head is too large to be reasonable. #[error("Message head is too large")] - TooLarge, + TooLarge(usize), /// A message reached EOF, but is not complete. #[error("Message is incomplete")] Incomplete, /// An invalid `Status`, such as `1337 ELITE`. #[error("Invalid Status provided")] Status, - /// A timeout occurred waiting for an IO event. - #[allow(dead_code)] - #[error("Timeout during parse")] - Timeout, /// An `InvalidInput` occurred while trying to parse incoming stream. #[error("`InvalidInput` occurred while trying to parse incoming stream: {0}")] InvalidInput(&'static str), @@ -93,22 +104,22 @@ pub enum ParseError { Utf8(#[from] Utf8Error), } -impl From for ParseError { - fn from(err: FromUtf8Error) -> ParseError { - ParseError::Utf8(err.utf8_error()) +impl From for DecodeError { + fn from(err: FromUtf8Error) -> DecodeError { + DecodeError::Utf8(err.utf8_error()) } } -impl From for ParseError { - fn from(err: httparse::Error) -> ParseError { +impl From for DecodeError { + fn from(err: httparse::Error) -> DecodeError { match err { httparse::Error::HeaderName | httparse::Error::HeaderValue | httparse::Error::NewLine - | httparse::Error::Token => ParseError::Header, - httparse::Error::Status => ParseError::Status, - httparse::Error::TooManyHeaders => ParseError::TooLarge, - httparse::Error::Version => ParseError::Version, + | httparse::Error::Token => DecodeError::Header, + httparse::Error::Status => DecodeError::Status, + httparse::Error::TooManyHeaders => DecodeError::TooLarge(0), + httparse::Error::Version => DecodeError::Version, } } } @@ -131,9 +142,9 @@ pub enum PayloadError { /// Http2 payload error #[error("{0}")] Http2Payload(#[from] h2::StreamError), - /// Parse error - #[error("Parse error: {0}")] - Parse(#[from] ParseError), + /// Decode error + #[error("Decode error: {0}")] + Decode(#[from] DecodeError), /// Io error #[error("{0}")] Io(#[from] io::Error), @@ -153,61 +164,11 @@ impl From> for PayloadError { pub enum DispatchError { /// Service error #[error("Service error")] - Service(Box), + Service(Box), - /// Upgrade service error - #[error("Upgrade service error: {0}")] - Upgrade(Box), - - /// Peer is disconnected, error indicates that peer is disconnected because of it - #[error("Disconnected: {0:?}")] - PeerGone(Option), - - /// Http request parse error. - #[error("Parse error: {0}")] - Parse(#[from] ParseError), - - /// Http response encoding error. - #[error("Encode error: {0}")] - Encode(io::Error), - - /// Http/2 error - #[error("{0}")] - H2(#[from] H2Error), - - /// The first request did not complete within the specified timeout. - #[error("The first request did not complete within the specified timeout")] - SlowRequestTimeout, - - /// Disconnect timeout. Makes sense for ssl streams. - #[error("Connection shutdown timeout")] - DisconnectTimeout, - - /// Payload is not consumed - #[error("Task is completed but request's payload is not consumed")] - PayloadIsNotConsumed, - - /// Malformed request - #[error("Malformed request")] - MalformedRequest, - - /// Response body processing error - #[error("Response body processing error: {0}")] - ResponsePayload(Box), - - /// Internal error - #[error("Internal error")] - InternalError, - - /// Unknown error - #[error("Unknown error")] - Unknown, -} - -impl From for DispatchError { - fn from(err: io::Error) -> Self { - DispatchError::PeerGone(Some(err)) - } + /// Control service error + #[error("Control service error: {0}")] + Control(Box), } #[derive(thiserror::Error, Debug)] @@ -296,15 +257,16 @@ mod tests { #[test] fn test_payload_error() { - let err: PayloadError = io::Error::new(io::ErrorKind::Other, "ParseError").into(); - assert!(format!("{}", err).contains("ParseError")); + let err: PayloadError = io::Error::new(io::ErrorKind::Other, "DecodeError").into(); + assert!(format!("{}", err).contains("DecodeError")); let err: PayloadError = BlockingError::Canceled.into(); assert!(format!("{}", err).contains("Operation is canceled")); let err: PayloadError = - BlockingError::Error(io::Error::new(io::ErrorKind::Other, "ParseError")).into(); - assert!(format!("{}", err).contains("ParseError")); + BlockingError::Error(io::Error::new(io::ErrorKind::Other, "DecodeError")) + .into(); + assert!(format!("{}", err).contains("DecodeError")); let err = PayloadError::Incomplete(None); assert_eq!( @@ -315,7 +277,7 @@ mod tests { macro_rules! from { ($from:expr => $error:pat) => { - match ParseError::from($from) { + match DecodeError::from($from) { e @ $error => { assert!(format!("{}", e).len() >= 5); } @@ -326,13 +288,13 @@ mod tests { #[test] fn test_from() { - from!(httparse::Error::HeaderName => ParseError::Header); - from!(httparse::Error::HeaderName => ParseError::Header); - from!(httparse::Error::HeaderValue => ParseError::Header); - from!(httparse::Error::NewLine => ParseError::Header); - from!(httparse::Error::Status => ParseError::Status); - from!(httparse::Error::Token => ParseError::Header); - from!(httparse::Error::TooManyHeaders => ParseError::TooLarge); - from!(httparse::Error::Version => ParseError::Version); + from!(httparse::Error::HeaderName => DecodeError::Header); + from!(httparse::Error::HeaderName => DecodeError::Header); + from!(httparse::Error::HeaderValue => DecodeError::Header); + from!(httparse::Error::NewLine => DecodeError::Header); + from!(httparse::Error::Status => DecodeError::Status); + from!(httparse::Error::Token => DecodeError::Header); + from!(httparse::Error::TooManyHeaders => DecodeError::TooLarge(0)); + from!(httparse::Error::Version => DecodeError::Version); } } diff --git a/ntex/src/http/h1/client.rs b/ntex/src/http/h1/client.rs index d1f11b5e..b222d178 100644 --- a/ntex/src/http/h1/client.rs +++ b/ntex/src/http/h1/client.rs @@ -1,11 +1,11 @@ -use std::{cell::Cell, cell::RefCell, io}; +use std::{cell::Cell, cell::RefCell}; use bitflags::bitflags; use crate::codec::{Decoder, Encoder}; use crate::http::body::BodySize; use crate::http::config::DateService; -use crate::http::error::{ParseError, PayloadError}; +use crate::http::error::{DecodeError, EncodeError, PayloadError}; use crate::http::message::{ConnectionType, RequestHeadType, ResponseHead}; use crate::http::{Method, Version}; use crate::util::{Bytes, BytesMut}; @@ -117,7 +117,7 @@ impl ClientPayloadCodec { impl Decoder for ClientCodec { type Item = ResponseHead; - type Error = ParseError; + type Error = DecodeError; fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { debug_assert!( @@ -191,7 +191,7 @@ impl Decoder for ClientPayloadCodec { impl Encoder for ClientCodec { type Item = Message<(RequestHeadType, BodySize)>; - type Error = io::Error; + type Error = EncodeError; fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { diff --git a/ntex/src/http/h1/codec.rs b/ntex/src/http/h1/codec.rs index eec7f09c..35c8cc56 100644 --- a/ntex/src/http/h1/codec.rs +++ b/ntex/src/http/h1/codec.rs @@ -1,11 +1,11 @@ -use std::{cell::Cell, fmt, io}; +use std::{cell::Cell, fmt}; use bitflags::bitflags; use crate::codec::{Decoder, Encoder}; use crate::http::body::BodySize; use crate::http::config::DateService; -use crate::http::error::ParseError; +use crate::http::error::{DecodeError, EncodeError}; use crate::http::message::ConnectionType; use crate::http::request::Request; use crate::http::response::Response; @@ -124,7 +124,7 @@ impl Codec { impl Decoder for Codec { type Item = (Request, PayloadType); - type Error = ParseError; + type Error = DecodeError; fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { if let Some((req, payload)) = self.decoder.decode(src)? { @@ -155,7 +155,7 @@ impl Decoder for Codec { impl Encoder for Codec { type Item = Message<(Response<()>, BodySize)>; - type Error = io::Error; + type Error = EncodeError; fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { diff --git a/ntex/src/http/h1/control.rs b/ntex/src/http/h1/control.rs new file mode 100644 index 00000000..79b338e0 --- /dev/null +++ b/ntex/src/http/h1/control.rs @@ -0,0 +1,471 @@ +use std::{future::Future, io}; + +use crate::http::message::CurrentIo; +use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError}; +use crate::io::{Filter, Io, IoBoxed, IoRef}; + +pub enum Control { + /// New connection + NewConnection(Connection), + /// New request is loaded + NewRequest(NewRequest), + /// Handle `Connection: UPGRADE` + Upgrade(Upgrade), + /// Handle `EXPECT` header + Expect(Expect), + /// Underlying transport connection closed + Closed(Closed), + /// Application level error + Error(Error), + /// Protocol level error + ProtocolError(ProtocolError), + /// Peer is gone + PeerGone(PeerGone), +} + +/// Control message handling result +#[derive(Debug)] +pub struct ControlAck { + pub(super) result: ControlResult, + pub(super) flags: ControlFlags, +} + +bitflags::bitflags! { + #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + pub(super) struct ControlFlags: u8 { + /// Disconnect after request handling + const DISCONNECT = 0b0000_0001; + /// Handle expect-continue + const CONTINUE = 0b0001_0000; + } +} + +#[derive(Debug)] +pub(super) enum ControlResult { + // handle request expect + Expect(Request), + // handle request upgrade + Upgrade(Request), + // forward request to publish service + Publish(Request), + // forward request to publish service + PublishUpgrade(Request), + // send response + Response(Response<()>, Body), + // send response + ResponseWithIo(Response<()>, Body, IoBoxed), + // drop connection + Stop, +} + +impl Control { + pub(super) fn err(err: Err) -> Self + where + Err: ResponseError, + { + Control::Error(Error::new(err)) + } + + pub(super) const fn closed() -> Self { + Control::Closed(Closed) + } + + pub(super) fn new_req(req: Request) -> Self { + Control::NewRequest(NewRequest(req)) + } + + pub(super) fn con(io: IoRef) -> Self { + Control::NewConnection(Connection { io }) + } + + pub(super) fn upgrade(req: Request, io: Io, codec: Codec) -> Self { + Control::Upgrade(Upgrade { req, io, codec }) + } + + pub(super) fn expect(req: Request) -> Self { + Control::Expect(Expect(req)) + } + + pub(super) fn peer_gone(err: Option) -> Self { + Control::PeerGone(PeerGone(err)) + } + + pub(super) fn proto_err(err: super::ProtocolError) -> Self { + Control::ProtocolError(ProtocolError(err)) + } + + #[inline] + /// Ack control message + pub fn ack(self) -> ControlAck + where + F: Filter, + Err: ResponseError, + { + match self { + Control::NewConnection(msg) => msg.ack(), + Control::NewRequest(msg) => msg.ack(), + Control::Upgrade(msg) => msg.ack(), + Control::Expect(msg) => msg.ack(), + Control::Closed(msg) => msg.ack(), + Control::Error(msg) => msg.ack(), + Control::ProtocolError(msg) => msg.ack(), + Control::PeerGone(msg) => msg.ack(), + } + } +} + +#[derive(Debug)] +pub struct Connection { + io: IoRef, +} + +impl Connection { + #[inline] + /// Returns reference to Io + pub fn io(&self) -> &IoRef { + &self.io + } + + #[inline] + /// Ack and continue handling process + pub fn ack(self) -> ControlAck { + ControlAck { + result: ControlResult::Stop, + flags: ControlFlags::empty(), + } + } + + #[inline] + /// Drop connection + pub fn disconnect(self) -> ControlAck { + ControlAck { + result: ControlResult::Stop, + flags: ControlFlags::DISCONNECT, + } + } +} + +#[derive(Debug)] +pub struct NewRequest(Request); + +impl NewRequest { + #[inline] + /// Returns reference to http request + pub fn get_ref(&self) -> &Request { + &self.0 + } + + #[inline] + /// Returns mut reference to http request + pub fn get_mut(&mut self) -> &mut Request { + &mut self.0 + } + + #[inline] + /// Ack new request and continue handling process + pub fn ack(self) -> ControlAck { + let result = if self.0.head().expect() { + ControlResult::Expect(self.0) + } else if self.0.upgrade() { + ControlResult::Upgrade(self.0) + } else { + ControlResult::Publish(self.0) + }; + ControlAck { + result, + flags: ControlFlags::empty(), + } + } + + #[inline] + /// Fail request handling + pub fn fail(self, err: E) -> ControlAck { + let res: Response = (&err).into(); + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::empty(), + } + } + + #[inline] + /// Fail request and send custom response + pub fn fail_with(self, res: Response) -> ControlAck { + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::empty(), + } + } +} + +#[derive(Debug)] +pub struct Upgrade { + req: Request, + io: Io, + codec: Codec, +} + +impl Upgrade { + #[inline] + /// Returns reference to Io + pub fn io(&self) -> &Io { + &self.io + } + + #[inline] + /// Returns reference to http request + pub fn get_ref(&self) -> &Request { + &self.req + } + + #[inline] + /// Returns mut reference to http request + pub fn get_mut(&mut self) -> &mut Request { + &mut self.req + } + + #[inline] + /// Ack upgrade request and continue handling process + pub fn ack(mut self) -> ControlAck { + // Move io into request + let io: IoBoxed = self.io.into(); + io.stop_timer(); + self.req.head_mut().io = CurrentIo::new(io, self.codec); + + ControlAck { + result: ControlResult::PublishUpgrade(self.req), + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Handle upgrade request + pub fn handle(self, f: H) -> ControlAck + where + H: FnOnce(Request, Io, Codec) -> R + 'static, + R: Future, + { + crate::rt::spawn(async move { + let _ = f(self.req, self.io, self.codec).await; + }); + ControlAck { + result: ControlResult::Stop, + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Fail request handling + pub fn fail(self, err: E) -> ControlAck { + let res: Response = (&err).into(); + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::ResponseWithIo(res, body.into(), self.io.into()), + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Fail request and send custom response + pub fn fail_with(self, res: Response) -> ControlAck { + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::ResponseWithIo(res, body.into(), self.io.into()), + flags: ControlFlags::DISCONNECT, + } + } +} + +/// Connection closed message +#[derive(Debug)] +pub struct Closed; + +impl Closed { + #[inline] + /// convert packet to a result + pub fn ack(self) -> ControlAck { + ControlAck { + result: ControlResult::Stop, + flags: ControlFlags::empty(), + } + } +} + +/// Service level error +#[derive(Debug)] +pub struct Error { + err: Err, + pkt: Response, +} + +impl Error { + fn new(err: Err) -> Self { + Self { + pkt: err.error_response(), + err, + } + } + + #[inline] + /// Returns reference to http error + pub fn get_ref(&self) -> &Err { + &self.err + } + + #[inline] + /// Ack service error and close connection. + pub fn ack(self) -> ControlAck { + let (res, body) = self.pkt.into_parts(); + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Fail error handling + pub fn fail(self, err: E) -> ControlAck { + let res: Response = (&err).into(); + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Fail error handling + pub fn fail_with(self, res: Response) -> ControlAck { + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } +} + +#[derive(Debug)] +pub struct ProtocolError(super::ProtocolError); + +impl ProtocolError { + #[inline] + /// Returns error reference + pub fn err(&self) -> &super::ProtocolError { + &self.0 + } + + #[inline] + /// Ack ProtocolError message + pub fn ack(self) -> ControlAck { + let (res, body) = self.0.error_response().into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Fail error handling + pub fn fail(self, err: E) -> ControlAck { + let res: Response = (&err).into(); + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Fail error handling + pub fn fail_with(self, res: Response) -> ControlAck { + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } +} + +#[derive(Debug)] +pub struct PeerGone(Option); + +impl PeerGone { + #[inline] + /// Returns error reference + pub fn err(&self) -> Option<&io::Error> { + self.0.as_ref() + } + + #[inline] + /// Take error + pub fn take(&mut self) -> Option { + self.0.take() + } + + #[inline] + /// Ack PeerGone message + pub fn ack(self) -> ControlAck { + ControlAck { + result: ControlResult::Stop, + flags: ControlFlags::DISCONNECT, + } + } +} + +#[derive(Debug)] +pub struct Expect(Request); + +impl Expect { + #[inline] + /// Returns reference to http request + pub fn get_ref(&self) -> &Request { + &self.0 + } + + #[inline] + /// Ack expect request + pub fn ack(self) -> ControlAck { + let result = if self.0.upgrade() { + ControlResult::Upgrade(self.0) + } else { + ControlResult::Publish(self.0) + }; + ControlAck { + result, + flags: ControlFlags::CONTINUE, + } + } + + #[inline] + /// Fail expect request + pub fn fail(self, err: E) -> ControlAck { + let res: Response = (&err).into(); + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } + + #[inline] + /// Fail expect request and send custom response + pub fn fail_with(self, res: Response) -> ControlAck { + let (res, body) = res.into_parts(); + + ControlAck { + result: ControlResult::Response(res, body.into()), + flags: ControlFlags::DISCONNECT, + } + } +} diff --git a/ntex/src/http/h1/decoder.rs b/ntex/src/http/h1/decoder.rs index a5c82e50..16160337 100644 --- a/ntex/src/http/h1/decoder.rs +++ b/ntex/src/http/h1/decoder.rs @@ -4,7 +4,7 @@ use ntex_http::header::{HeaderName, HeaderValue}; use ntex_http::{header, Method, StatusCode, Uri, Version}; use crate::codec::Decoder; -use crate::http::error::ParseError; +use crate::http::error::DecodeError; use crate::http::header::HeaderMap; use crate::http::message::{ConnectionType, ResponseHead}; use crate::http::request::Request; @@ -40,7 +40,7 @@ impl Clone for MessageDecoder { impl Decoder for MessageDecoder { type Item = (T, PayloadType); - type Error = ParseError; + type Error = DecodeError; fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { T::decode(src) @@ -78,14 +78,14 @@ pub(super) trait MessageType: Sized { fn headers_mut(&mut self) -> &mut HeaderMap; - fn decode(src: &mut BytesMut) -> Result, ParseError>; + fn decode(src: &mut BytesMut) -> Result, DecodeError>; fn set_headers( &mut self, slice: &Bytes, version: Version, raw_headers: &[HeaderIndex], - ) -> Result { + ) -> Result { let mut ka = None; let mut has_upgrade = false; let mut expect = false; @@ -108,12 +108,12 @@ pub(super) trait MessageType: Sized { match name { header::CONTENT_LENGTH if content_length.is_some() || chunked => { log::debug!("multiple Content-Length not allowed"); - return Err(ParseError::Header); + return Err(DecodeError::Header); } header::CONTENT_LENGTH => match value.to_str() { Ok(s) if s.trim_start().starts_with('+') => { log::debug!("illegal Content-Length: {:?}", s); - return Err(ParseError::Header); + return Err(DecodeError::Header); } Ok(s) => { if let Ok(len) = s.parse::() { @@ -122,18 +122,18 @@ pub(super) trait MessageType: Sized { content_length = Some(len); } else { log::debug!("illegal Content-Length: {:?}", s); - return Err(ParseError::Header); + return Err(DecodeError::Header); } } Err(_) => { log::debug!("illegal Content-Length: {:?}", value); - return Err(ParseError::Header); + return Err(DecodeError::Header); } }, // transfer-encoding header::TRANSFER_ENCODING if seen_te => { log::debug!("Transfer-Encoding header usage is not allowed"); - return Err(ParseError::Header); + return Err(DecodeError::Header); } header::TRANSFER_ENCODING if version == Version::HTTP_11 => { seen_te = true; @@ -145,10 +145,10 @@ pub(super) trait MessageType: Sized { // allow silently since multiple TE headers are already checked } else { log::debug!("illegal Transfer-Encoding: {:?}", s); - return Err(ParseError::Header); + return Err(DecodeError::Header); } } else { - return Err(ParseError::Header); + return Err(DecodeError::Header); } } // connection keep-alive state @@ -228,7 +228,7 @@ impl MessageType for Request { &mut self.head_mut().headers } - fn decode(src: &mut BytesMut) -> Result, ParseError> { + fn decode(src: &mut BytesMut) -> Result, DecodeError> { let mut headers: [mem::MaybeUninit; MAX_HEADERS] = uninit_array(); let (len, method, uri, ver, headers) = { @@ -240,7 +240,7 @@ impl MessageType for Request { match req.parse_with_uninit_headers(src, &mut parsed)? { httparse::Status::Complete(len) => { let method = Method::from_bytes(req.method.unwrap().as_bytes()) - .map_err(|_| ParseError::Method)?; + .map_err(|_| DecodeError::Method)?; let uri = Uri::try_from(req.path.unwrap())?; let version = if req.version.unwrap() == 1 { Version::HTTP_11 @@ -259,7 +259,7 @@ impl MessageType for Request { httparse::Status::Partial => { if src.len() >= MAX_BUFFER_SIZE { trace!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - return Err(ParseError::TooLarge); + return Err(DecodeError::TooLarge(src.len())); } return Ok(None); } @@ -275,7 +275,7 @@ impl MessageType for Request { // see https://datatracker.ietf.org/doc/html/rfc1945#section-7.2.2 if ver == Version::HTTP_10 && method == Method::POST && length.is_none() { debug!("no Content-Length specified for HTTP/1.0 POST request"); - return Err(ParseError::Header); + return Err(DecodeError::Header); } // Remove CL value if 0 now that all headers and HTTP/1.0 special cases are processed. @@ -325,7 +325,7 @@ impl MessageType for ResponseHead { &mut self.headers } - fn decode(src: &mut BytesMut) -> Result, ParseError> { + fn decode(src: &mut BytesMut) -> Result, DecodeError> { let mut headers: [mem::MaybeUninit; MAX_HEADERS] = uninit_array(); let (len, ver, status, headers) = { @@ -345,7 +345,7 @@ impl MessageType for ResponseHead { Version::HTTP_10 }; let status = StatusCode::from_u16(res.code.unwrap()) - .map_err(|_| ParseError::Status)?; + .map_err(|_| DecodeError::Status)?; ( len, @@ -357,7 +357,7 @@ impl MessageType for ResponseHead { httparse::Status::Partial => { return if src.len() >= MAX_BUFFER_SIZE { log::error!("MAX_BUFFER_SIZE unprocessed data reached, closing"); - Err(ParseError::TooLarge) + Err(DecodeError::TooLarge(src.len())) } else { Ok(None) }; @@ -514,7 +514,7 @@ enum ChunkedState { impl Decoder for PayloadDecoder { type Item = PayloadItem; - type Error = ParseError; + type Error = DecodeError; fn decode(&self, src: &mut BytesMut) -> Result, Self::Error> { let mut kind = self.kind.get(); @@ -595,7 +595,7 @@ impl ChunkedState { body: &mut BytesMut, size: &mut u64, buf: &mut Option, - ) -> Poll> { + ) -> Poll> { use self::ChunkedState::*; match *self { Size => ChunkedState::read_size(body, size), @@ -614,7 +614,7 @@ impl ChunkedState { fn read_size( rdr: &mut BytesMut, size: &mut u64, - ) -> Poll> { + ) -> Poll> { let rem = match byte!(rdr) { b @ b'0'..=b'9' => b - b'0', b @ b'a'..=b'f' => b + 10 - b'a', @@ -623,7 +623,7 @@ impl ChunkedState { b';' => return Poll::Ready(Ok(ChunkedState::Extension)), b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)), _ => { - return Poll::Ready(Err(ParseError::InvalidInput( + return Poll::Ready(Err(DecodeError::InvalidInput( "Invalid chunk size line: Invalid Size", ))); } @@ -638,43 +638,43 @@ impl ChunkedState { } None => { log::debug!("chunk size would overflow u64"); - Poll::Ready(Err(ParseError::InvalidInput( + Poll::Ready(Err(DecodeError::InvalidInput( "Invalid chunk size line: Size is too big", ))) } } } - fn read_size_lws(rdr: &mut BytesMut) -> Poll> { + fn read_size_lws(rdr: &mut BytesMut) -> Poll> { log::trace!("read_size_lws"); match byte!(rdr) { // LWS can follow the chunk size, but no more digits can come b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)), b';' => Poll::Ready(Ok(ChunkedState::Extension)), b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), - _ => Poll::Ready(Err(ParseError::InvalidInput( + _ => Poll::Ready(Err(DecodeError::InvalidInput( "Invalid chunk size linear white space", ))), } } - fn read_extension(rdr: &mut BytesMut) -> Poll> { + fn read_extension(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)), // strictly 0x20 (space) should be disallowed but we don't parse quoted strings here - 0x00..=0x08 | 0x0a..=0x1f | 0x7f => Poll::Ready(Err(ParseError::InvalidInput( - "Invalid character in chunk extension", - ))), + 0x00..=0x08 | 0x0a..=0x1f | 0x7f => Poll::Ready(Err( + DecodeError::InvalidInput("Invalid character in chunk extension"), + )), _ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions } } fn read_size_lf( rdr: &mut BytesMut, size: &mut u64, - ) -> Poll> { + ) -> Poll> { match byte!(rdr) { b'\n' if *size > 0 => Poll::Ready(Ok(ChunkedState::Body)), b'\n' if *size == 0 => Poll::Ready(Ok(ChunkedState::EndCr)), - _ => Poll::Ready(Err(ParseError::InvalidInput("Invalid chunk size LF"))), + _ => Poll::Ready(Err(DecodeError::InvalidInput("Invalid chunk size LF"))), } } @@ -682,7 +682,7 @@ impl ChunkedState { rdr: &mut BytesMut, rem: &mut u64, buf: &mut Option, - ) -> Poll> { + ) -> Poll> { log::trace!("Chunked read, remaining={:?}", rem); let len = rdr.len() as u64; @@ -706,28 +706,28 @@ impl ChunkedState { } } - fn read_body_cr(rdr: &mut BytesMut) -> Poll> { + fn read_body_cr(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)), - _ => Poll::Ready(Err(ParseError::InvalidInput("Invalid chunk body CR"))), + _ => Poll::Ready(Err(DecodeError::InvalidInput("Invalid chunk body CR"))), } } - fn read_body_lf(rdr: &mut BytesMut) -> Poll> { + fn read_body_lf(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { b'\n' => Poll::Ready(Ok(ChunkedState::Size)), - _ => Poll::Ready(Err(ParseError::InvalidInput("Invalid chunk body LF"))), + _ => Poll::Ready(Err(DecodeError::InvalidInput("Invalid chunk body LF"))), } } - fn read_end_cr(rdr: &mut BytesMut) -> Poll> { + fn read_end_cr(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)), - _ => Poll::Ready(Err(ParseError::InvalidInput("Invalid chunk end CR"))), + _ => Poll::Ready(Err(DecodeError::InvalidInput("Invalid chunk end CR"))), } } - fn read_end_lf(rdr: &mut BytesMut) -> Poll> { + fn read_end_lf(rdr: &mut BytesMut) -> Poll> { match byte!(rdr) { b'\n' => Poll::Ready(Ok(ChunkedState::End)), - _ => Poll::Ready(Err(ParseError::InvalidInput("Invalid chunk end LF"))), + _ => Poll::Ready(Err(DecodeError::InvalidInput("Invalid chunk end LF"))), } } } diff --git a/ntex/src/http/h1/default.rs b/ntex/src/http/h1/default.rs new file mode 100644 index 00000000..4ec6f3a4 --- /dev/null +++ b/ntex/src/http/h1/default.rs @@ -0,0 +1,43 @@ +use std::io; + +use crate::http::ResponseError; +use crate::io::Filter; +use crate::service::{Service, ServiceCtx, ServiceFactory}; + +use super::control::{Control, ControlAck}; + +#[derive(Default)] +/// Default control service +pub struct DefaultControlService; + +impl ServiceFactory> for DefaultControlService +where + F: Filter, + Err: ResponseError, +{ + type Response = ControlAck; + type Error = io::Error; + type Service = DefaultControlService; + type InitError = io::Error; + + async fn create(&self, _: ()) -> Result { + Ok(DefaultControlService) + } +} + +impl Service> for DefaultControlService +where + F: Filter, + Err: ResponseError, +{ + type Response = ControlAck; + type Error = io::Error; + + async fn call( + &self, + req: Control, + _: ServiceCtx<'_, Self>, + ) -> Result { + Ok(req.ack()) + } +} diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index a6ed26bc..92a3e094 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,36 +1,32 @@ -//! Framed transport dispatcher -use std::task::{Context, Poll}; -use std::{cell::RefCell, error::Error, future::Future, io, marker, pin::Pin, rc::Rc}; +//! HTTP/1 protocol dispatcher +use std::{error, future, io, marker, pin::Pin, rc::Rc, task::Context, task::Poll}; -use crate::io::{Decoded, Filter, Io, IoBoxed, IoRef, IoStatusUpdate, RecvError}; -use crate::service::{Pipeline, PipelineCall, Service}; +use crate::io::{Decoded, Filter, Io, IoBoxed, IoStatusUpdate, RecvError}; +use crate::service::{PipelineCall, Service}; use crate::time::Seconds; -use crate::util::{ready, Bytes}; +use crate::util::{ready, Either}; -use crate::http; use crate::http::body::{BodySize, MessageBody, ResponseBody}; -use crate::http::config::{DispatcherConfig, OnRequest}; -use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError}; +use crate::http::error::{PayloadError, ResponseError}; use crate::http::message::{ConnectionType, CurrentIo}; -use crate::http::request::Request; -use crate::http::response::Response; +use crate::http::{self, config::DispatcherConfig, request::Request, response::Response}; +use super::control::{Control, ControlAck, ControlFlags, ControlResult}; use super::decoder::{PayloadDecoder, PayloadItem, PayloadType}; use super::payload::{Payload, PayloadSender, PayloadStatus}; -use super::{codec::Codec, Message}; +use super::{codec::Codec, Message, ProtocolError}; bitflags::bitflags! { #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct Flags: u8 { - /// Upgrade request + /// Upgrade hnd const UPGRADE = 0b0000_0001; - /// Handling upgrade - const UPGRADE_HND = 0b0000_0010; - /// Stop after sending payload - const SENDPAYLOAD_AND_STOP = 0b0000_0100; - + /// Stopping + const SENDPAYLOAD_AND_STOP = 0b0000_0010; + /// Complete operation and disconnect + const DISCONNECT = 0b0000_0100; /// Keep-alive is enabled - const READ_KA_TIMEOUT = 0b0001_0000; + const READ_KA_TIMEOUT = 0b0001_0000; /// Read headers timer is enabled const READ_HDRS_TIMEOUT = 0b0010_0000; /// Read headers payload is enabled @@ -40,58 +36,50 @@ bitflags::bitflags! { pin_project_lite::pin_project! { /// Dispatcher for HTTP/1.1 protocol - pub struct Dispatcher, B, X: Service, U: Service<(Request, Io, Codec)>> - where S: 'static, X: 'static, U: 'static + pub struct Dispatcher, B, C: Service>> + where + F: 'static, + S::Error: 'static, { - #[pin] - call: CallState, - st: State, - inner: DispatcherInner, + st: State, + inner: DispatcherInner, } } -#[derive(Debug, thiserror::Error)] -enum State { - #[error("State::Call")] - Call, - #[error("State::ReadRequest")] +#[derive(Debug)] +enum State +where + F: 'static, + S: Service, + S::Error: 'static, + C: Service>, +{ + CallPublish { + fut: PipelineCall, + }, + CallControl { + fut: PipelineCall>, + }, ReadRequest, - #[error("State::ReadPayload")] ReadPayload, - #[error("State::SendPayload")] - SendPayload { body: ResponseBody }, - #[error("State::SendPayloadAndStop")] + SendPayload { + body: ResponseBody, + }, SendPayloadAndStop { body: ResponseBody, - boxed_io: Option>, + io: IoBoxed, + }, + Stop { + fut: Option>>, + io: Option, }, - #[error("State::Upgrade")] - Upgrade(Option), - #[error("State::StopIo")] - StopIo(Box<(IoBoxed, Codec)>), - #[error("State::Stop")] - Stop, } -pin_project_lite::pin_project! { - #[project = CallStateProject] - enum CallState, X: Service> - where S: 'static, X: 'static - { - None, - Service { #[pin] fut: PipelineCall }, - ServiceUpgrade { #[pin] fut: PipelineCall }, - Expect { #[pin] fut: PipelineCall }, - Filter { fut: PipelineCall } - } -} - -struct DispatcherInner { +struct DispatcherInner { io: Io, flags: Flags, codec: Codec, - config: Rc>, - error: Option, + config: Rc>, payload: Option<(PayloadDecoder, PayloadSender)>, read_remains: u32, read_consumed: u32, @@ -99,19 +87,17 @@ struct DispatcherInner { _t: marker::PhantomData<(S, B)>, } -impl Dispatcher +impl Dispatcher where F: Filter, + C: Service, Response = ControlAck>, S: Service, S::Error: ResponseError, S::Response: Into>, B: MessageBody, - X: Service, - X::Error: ResponseError, - U: Service<(Request, Io, Codec), Response = ()>, { /// Construct new `Dispatcher` instance with outgoing messages stream. - pub(in crate::http) fn new(io: Io, config: Rc>) -> Self { + pub(in crate::http) fn new(io: Io, config: Rc>) -> Self { let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled()); io.set_disconnect_timeout(config.client_disconnect); @@ -124,14 +110,12 @@ where }; Dispatcher { - call: CallState::None, st: State::ReadRequest, inner: DispatcherInner { io, flags, codec, config, - error: None, payload: None, read_remains: 0, read_consumed: 0, @@ -142,497 +126,230 @@ where } } -impl Future for Dispatcher +impl future::Future for Dispatcher where F: Filter, - S: Service, + C: Service, Response = ControlAck> + 'static, + C::Error: error::Error, + S: Service + 'static, S::Error: ResponseError + 'static, S::Response: Into>, B: MessageBody, - X: Service, - X::Error: ResponseError + 'static, - U: Service<(Request, Io, Codec), Response = ()> + 'static, { - type Output = Result<(), DispatchError>; + type Output = Result<(), Box>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut().project(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + let inner = &mut this.inner; loop { - match this.st { - State::Call => { - let next = match this.call.project() { - CallStateProject::Service { fut } => { - match fut.poll(cx) { - Poll::Ready(result) => match result { - Ok(res) => { - let (res, body) = res.into().into_parts(); - *this.st = this.inner.send_response(res, body); - } - Err(e) => *this.st = this.inner.handle_error(e, false), - }, - Poll::Pending => { - // we might need to read more data into a request payload - // (ie service future can wait for payload data) - if this.inner.payload.is_some() { - if let Err(e) = - ready!(this.inner.poll_request_payload(cx)) - { - *this.st = State::Stop; - this.inner.error = Some(e); - } - } else if this.inner.poll_io_closed(cx) { - // check if io is closed - *this.st = State::Stop; - } else { - return Poll::Pending; - } - } - } - None + *this.st = match this.st { + // handle publish service responses + State::CallPublish { fut } => match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(res)) => { + let (res, body) = res.into().into_parts(); + if inner.flags.contains(Flags::UPGRADE) { + inner.send_response_to(res, body, None) + } else { + inner.send_response(res, body) } - // special handling for upgrade requests. - // we cannot continue to handle requests, because Io get - // converted to IoBoxed before we set it to request, - // so we have to send response and disconnect. request payload - // handling should be handled by service - CallStateProject::ServiceUpgrade { fut } => { - match ready!(fut.poll(cx)) { - Ok(res) => { - let (msg, body) = res.into().into_parts(); - let io = if let Some(item) = msg.head().take_io() { - item - } else { - log::trace!("Handler service consumed io, stop"); - return Poll::Ready(Ok(())); - }; - - io.1.set_ctype(ConnectionType::Close); - io.1.unset_streaming(); - let result = io - .0 - .encode(Message::Item((msg, body.size())), &io.1); - if result.is_ok() { - match body.size() { - BodySize::None | BodySize::Empty => { - *this.st = State::StopIo(io) - } - _ => { - *this.st = State::SendPayloadAndStop { - body, - boxed_io: Some(io), - } - } - } - } else { - *this.st = State::StopIo(io); - } - } - Err(e) => { - log::error!( - "Cannot handle error for upgrade handler: {:?}", - e - ); - return Poll::Ready(Ok(())); - } - } - None - } - // handle EXPECT call - // expect service call must resolve before - // we can do any more io processing. - // - // TODO: check keep-alive timer interaction - CallStateProject::Expect { fut } => match ready!(fut.poll(cx)) { - Ok(req) => { - let result = this.inner.io.with_write_buf(|buf| { - buf.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n") - }); - if result.is_err() { - log::error!( - "{}: Expect handler returned error: {:?}", - this.inner.io.tag(), - result.err().unwrap() - ); - *this.st = State::Stop; - this = self.as_mut().project(); - continue; - } else if this.inner.flags.contains(Flags::UPGRADE) { - *this.st = State::Upgrade(Some(req)); - this = self.as_mut().project(); - continue; - } else if this.inner.flags.contains(Flags::UPGRADE_HND) { - Some(this.inner.service_upgrade(req)) - } else { - Some(this.inner.service_call(req)) - } - } - Err(e) => { - *this.st = this.inner.handle_error(e, true); - None - } - }, - // handle FILTER call - CallStateProject::Filter { fut } => { - match ready!(Pin::new(fut).poll(cx)) { - Ok(req) => { - this.inner - .codec - .set_ctype(req.head().connection_type()); - if req.head().expect() { - Some(this.inner.service_expect(req)) - } else if this.inner.flags.contains(Flags::UPGRADE_HND) - { - Some(this.inner.service_upgrade(req)) - } else { - Some(this.inner.service_call(req)) - } - } - Err(res) => { - let (res, body) = res.into_parts(); - *this.st = - this.inner.send_response(res, body.into_body()); - None - } - } - } - CallStateProject::None => unreachable!(), - }; - - this = self.as_mut().project(); - if let Some(next) = next { - this.call.set(next); } - } + Poll::Ready(Err(err)) => inner.control(Control::err(err)), + Poll::Pending => { + if !inner.flags.contains(Flags::UPGRADE) { + ready!(inner.poll_request(cx)) + } else { + return Poll::Pending; + } + } + }, + // handle control service responses + State::CallControl { fut } => match Pin::new(fut).poll(cx) { + Poll::Ready(Ok(ControlAck { result, flags })) => { + if flags.contains(ControlFlags::CONTINUE) { + let result = inner.io.with_write_buf(|buf| { + buf.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n") + }); + if let Err(err) = result { + *this.st = inner.ctl_peer_gone(Some(err)); + continue; + } + } + if flags.contains(ControlFlags::DISCONNECT) { + inner.flags.insert(Flags::DISCONNECT); + } + + match result { + ControlResult::Publish(req) => inner.publish(req), + ControlResult::PublishUpgrade(req) => { + inner.flags.insert(Flags::UPGRADE); + inner.publish(req) + } + ControlResult::Response(res, body) => { + inner.send_response(res, body.into()) + } + ControlResult::ResponseWithIo(res, body, io) => { + inner.send_response_to(res, body.into(), Some(io)) + } + ControlResult::Expect(req) => { + inner.control(Control::expect(req)) + } + ControlResult::Upgrade(req) => inner.ctl_upgrade(req), + ControlResult::Stop => inner.stop(), + } + } + Poll::Ready(Err(err)) => { + log::error!("{}: Control plain error: {}", inner.io.tag(), err); + return Poll::Ready(Err(Box::new(err))); + } + Poll::Pending => ready!(inner.poll_request(cx)), + }, // read request and call service - State::ReadRequest => { - *this.st = ready!(this.inner.read_request(cx, &mut this.call)); - } + State::ReadRequest => ready!(inner.poll_read_request(cx)), // consume request's payload State::ReadPayload => { - if let Err(e) = ready!(this.inner.poll_request_payload(cx)) { - *this.st = State::Stop; - this.inner.error = Some(e); - } else { - *this.st = State::ReadRequest; - } + ready!(inner.poll_request_payload(cx)).unwrap_or(State::ReadRequest) } // send response body - State::SendPayload { ref mut body } => { - if this.inner.io.is_closed() { - *this.st = State::Stop; - } else { - if let Poll::Ready(Err(err)) = this.inner.poll_request_payload(cx) { - this.inner.error = Some(err); - this.inner.flags.insert(Flags::SENDPAYLOAD_AND_STOP); - } - loop { - let _ = ready!(this.inner.io.poll_flush(cx, false)); - let item = ready!(body.poll_next_chunk(cx)); - if let Some(st) = this.inner.send_payload(item) { - *this.st = st; - break; - } - } - } + State::SendPayload { body } => { + ready!(inner.poll_send_payload(cx, body)) } // send response body - State::SendPayloadAndStop { - ref mut body, - ref mut boxed_io, - } => { - let io = boxed_io.as_ref().unwrap(); - - if io.0.is_closed() { - *this.st = State::Stop; - } else { - if let Poll::Ready(Err(err)) = - this.inner._poll_request_payload(Some(&io.0), cx) - { - this.inner.error = Some(err); - } - loop { - let _ = ready!(io.0.poll_flush(cx, false)); - let item = ready!(body.poll_next_chunk(cx)); - match item { - Some(Ok(item)) => { - trace!("got response chunk: {:?}", item.len()); - if let Err(e) = - io.0.encode(Message::Chunk(Some(item)), &io.1) - { - trace!("Cannot encode chunk: {:?}", e); - } else { - continue; - } - } - None => { - trace!("response payload eof {:?}", this.inner.flags); - if let Err(e) = io.0.encode(Message::Chunk(None), &io.1) - { - trace!("Cannot encode payload eof: {:?}", e); - } - } - Some(Err(e)) => { - trace!("error during response body poll: {:?}", e); - } - } - *this.st = State::StopIo(boxed_io.take().unwrap()); - break; - } + State::SendPayloadAndStop { body, io } => { + ready!(inner.poll_send_payload_to(cx, body, io)) + } + // shutdown io + State::Stop { fut, io } => { + if let Some(ref mut f) = fut { + let _ = ready!(Pin::new(f).poll(cx)); + fut.take(); } - } - // stop io tasks and call upgrade service - State::Upgrade(ref mut req) => { - let req = req.take().unwrap(); - let io = this.inner.io.take(); - io.stop_timer(); - log::trace!( - "{}: Switching to upgrade service for {:?}", - this.inner.io.tag(), - req - ); - - // Handle UPGRADE request - let config = this.inner.config.clone(); - let codec = this.inner.codec.clone(); - crate::rt::spawn(async move { - let _ = config - .upgrade - .as_ref() - .unwrap() - .call((req, io, codec)) - .await; - }); - return Poll::Ready(Ok(())); - } - // prepare to shutdown - State::Stop => { - this.inner.io.stop_timer(); - - return if let Err(e) = ready!(this.inner.io.poll_shutdown(cx)) { - // get io error - if let Some(e) = this.inner.error.take() { - Poll::Ready(Err(e)) + return Poll::Ready( + if let Some(io) = io { + io.stop_timer(); + ready!(io.poll_shutdown(cx)) } else { - Poll::Ready(Err(DispatchError::PeerGone(Some(e)))) + inner.io.stop_timer(); + ready!(inner.io.poll_shutdown(cx)) } - } else { - Poll::Ready(Ok(())) - }; - } - // prepare to shutdown - State::StopIo(ref item) => { - return item.0.poll_shutdown(cx).map_err(From::from) + .map_err(From::from), + ); } } } } } -impl DispatcherInner +impl DispatcherInner where - T: Filter, - S: Service, - S::Error: ResponseError + 'static, + F: Filter, + C: Service, Response = ControlAck> + 'static, + S: Service + 'static, + S::Error: ResponseError, S::Response: Into>, B: MessageBody, - X: Service, { - fn handle_error(&mut self, err: E, critical: bool) -> State - where - E: ResponseError + 'static, - { - let res: Response = (&err).into(); - let (res, body) = res.into_parts(); - let state = self.send_response(res, body.into_body()); - - // check if we can continue after error - if critical || self.payload.take().is_some() { - self.error = Some(DispatchError::Service(Box::new(err))); - if matches!(state, State::SendPayload { .. }) { - self.flags.insert(Flags::SENDPAYLOAD_AND_STOP); - state - } else { - State::Stop - } - } else { - state - } - } - - /// Handle normal requests - fn service_call(&self, req: Request) -> CallState { - CallState::Service { - fut: self.config.service.call_nowait(req), - } - } - - /// Handle filter fut - fn service_filter(&self, req: Request, f: &Pipeline) -> CallState { - CallState::Filter { - fut: f.call_nowait((req, self.io.get_ref())), - } - } - - /// Handle normal requests with EXPECT: 100-Continue` header - fn service_expect(&self, req: Request) -> CallState { - CallState::Expect { - fut: self.config.expect.call_nowait(req), - } - } - - /// Handle upgrade requests - fn service_upgrade(&mut self, mut req: Request) -> CallState { - // Move io into request - let io: IoBoxed = self.io.take().into(); - self.io.stop_timer(); - req.head_mut().io = CurrentIo::Io(Rc::new(( - io.get_ref(), - RefCell::new(Some(Box::new((io, self.codec.clone())))), - ))); - CallState::ServiceUpgrade { - fut: self.config.service.call_nowait(req), - } - } - - fn read_request( - &mut self, - cx: &mut Context<'_>, - call_state: &mut std::pin::Pin<&mut CallState>, - ) -> Poll> { + fn poll_read_request(&mut self, cx: &mut Context<'_>) -> Poll> { log::trace!("{}: Trying to read http message", self.io.tag()); - loop { - let result = match self.io.poll_recv_decode(&self.codec, cx) { - Ok(decoded) => { - if let Some(st) = self.update_hdrs_timer(&decoded) { - return Poll::Ready(st); - } - - if let Some(item) = decoded.item { - Ok(item) - } else { - return Poll::Pending; - } + let result = match self.io.poll_recv_decode(&self.codec, cx) { + Ok(decoded) => { + if let Some(st) = self.update_hdrs_timer(&decoded) { + return Poll::Ready(st); } - Err(err) => Err(err), - }; + if let Some(item) = decoded.item { + Ok(item) + } else { + return Poll::Pending; + } + } + Err(err) => Err(err), + }; - // decode incoming bytes stream - return match result { - Ok((mut req, pl)) => { - log::trace!( - "{}: Http message is received: {:?} and payload {:?}", - self.io.tag(), - req, - pl - ); + // decode incoming bytes stream + let st = match result { + Ok((mut req, pl)) => { + log::trace!( + "{}: Http message is received: {:?} and payload {:?}", + self.io.tag(), + req, + pl + ); + req.head_mut().io = CurrentIo::Ref(self.io.get_ref()); - // configure request payload - let upgrade = match pl { - PayloadType::None => false, - PayloadType::Payload(decoder) => { - let (ps, pl) = Payload::create(false); - req.replace_payload(http::Payload::H1(pl)); - self.payload = Some((decoder, ps)); - false - } - PayloadType::Stream(decoder) => { - if self.config.upgrade.is_none() { - let (ps, pl) = Payload::create(false); - req.replace_payload(http::Payload::H1(pl)); - self.payload = Some((decoder, ps)); - false - } else { - self.flags.insert(Flags::UPGRADE); - true - } - } - }; - - if upgrade { - // Handle UPGRADE request - log::trace!("{}: Prepare io for upgrade handler", self.io.tag()); - Poll::Ready(State::Upgrade(Some(req))) - } else { - if req.upgrade() { - self.flags.insert(Flags::UPGRADE_HND); - } else { - req.head_mut().io = CurrentIo::Ref(self.io.get_ref()); - } - call_state.set(if let Some(ref f) = self.config.on_request { - self.service_filter(req, f) - } else if req.head().expect() { - self.service_expect(req) - } else if self.flags.contains(Flags::UPGRADE_HND) { - self.service_upgrade(req) - } else { - self.service_call(req) - }); - Poll::Ready(State::Call) + // configure request payload + match pl { + PayloadType::None => (), + PayloadType::Payload(decoder) => { + let (ps, pl) = Payload::create(false); + req.replace_payload(http::Payload::H1(pl)); + self.payload = Some((decoder, ps)); } - } - Err(RecvError::WriteBackpressure) => { - if let Err(err) = ready!(self.io.poll_flush(cx, false)) { - log::trace!("{}: Peer is gone with {:?}", self.io.tag(), err); - self.error = Some(DispatchError::PeerGone(Some(err))); - Poll::Ready(State::Stop) - } else { - continue; + PayloadType::Stream(decoder) => { + let (ps, pl) = Payload::create(false); + req.replace_payload(http::Payload::H1(pl)); + self.payload = Some((decoder, ps)); } - } - Err(RecvError::Decoder(err)) => { - // Malformed requests, respond with 400 - log::trace!("{}: Malformed request: {:?}", self.io.tag(), err); - let (res, body) = Response::BadRequest().finish().into_parts(); - self.error = Some(DispatchError::Parse(err)); - Poll::Ready(self.send_response(res, body.into_body())) - } - Err(RecvError::PeerGone(err)) => { + }; + self.control(Control::new_req(req)) + } + Err(RecvError::WriteBackpressure) => { + if let Err(err) = ready!(self.io.poll_flush(cx, false)) { log::trace!("{}: Peer is gone with {:?}", self.io.tag(), err); - self.error = Some(DispatchError::PeerGone(err)); - Poll::Ready(State::Stop) + self.ctl_peer_gone(Some(err)) + } else { + ready!(self.poll_read_request(cx)) } - Err(RecvError::Stop) => { - log::trace!("{}: Dispatcher is instructed to stop", self.io.tag()); - Poll::Ready(State::Stop) - } - Err(RecvError::KeepAlive) => { - if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { - if let Err(err) = self.handle_timeout() { - log::trace!("{}: Slow request timeout", self.io.tag()); - let (req, body) = - Response::RequestTimeout().finish().into_parts(); - let _ = self.send_response(req, body.into_body()); - self.error = Some(err); - } else { - continue; - } + } + Err(RecvError::Decoder(err)) => { + // Malformed requests, respond with 400 + log::trace!("{}: Malformed request: {:?}", self.io.tag(), err); + self.ctl_proto_err(err.into()) + } + Err(RecvError::PeerGone(err)) => { + log::trace!("{}: Peer is gone with {:?}", self.io.tag(), err); + self.ctl_peer_gone(err) + } + Err(RecvError::KeepAlive) => { + if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { + if let Err(err) = self.handle_timeout() { + log::trace!("{}: Slow request timeout", self.io.tag()); + self.ctl_proto_err(err) } else { - log::trace!( - "{}: Keep-alive timeout, close connection", - self.io.tag() - ); + ready!(self.poll_read_request(cx)) } - Poll::Ready(State::Stop) + } else { + log::trace!("{}: Keep-alive timeout, close connection", self.io.tag()); + self.stop() } - }; - } + } + Err(RecvError::Stop) => { + log::trace!("{}: Dispatcher is instructed to stop", self.io.tag()); + self.stop() + } + }; + + Poll::Ready(st) } - fn send_response(&mut self, msg: Response<()>, body: ResponseBody) -> State { - trace!( + fn send_response( + &mut self, + msg: Response<()>, + body: ResponseBody, + ) -> State { + log::trace!( "{}: Sending response: {:?} body: {:?}", self.io.tag(), msg, body.size() ); + // we dont need to process responses if socket is disconnected // but we still want to handle requests with app service // so we skip response processing for droppped connection if self.io.is_closed() { - State::Stop + self.stop() } else { let result = self .io @@ -644,13 +361,14 @@ where err }); - if result.is_err() { - State::Stop - } else { - match body.size() { + match result { + Ok(()) => match body.size() { BodySize::None | BodySize::Empty => { - if self.error.is_some() { - State::Stop + if self + .flags + .intersects(Flags::DISCONNECT | Flags::SENDPAYLOAD_AND_STOP) + { + self.stop() } else if self.payload.is_some() { State::ReadPayload } else { @@ -658,48 +376,149 @@ where } } _ => State::SendPayload { body }, - } + }, + Err(_) if self.flags.contains(Flags::DISCONNECT) => self.stop(), + Err(err) => self.ctl_proto_err(err.into()), } } } - fn send_payload( + fn poll_send_payload( &mut self, - item: Option>>, - ) -> Option> { - match item { - Some(Ok(item)) => { - trace!("{}: Got response chunk: {:?}", self.io.tag(), item.len()); - match self.io.encode(Message::Chunk(Some(item)), &self.codec) { - Ok(_) => None, - Err(err) => { - self.error = Some(DispatchError::Encode(err)); - Some(State::Stop) + cx: &mut Context<'_>, + body: &mut ResponseBody, + ) -> Poll> { + if self.io.is_closed() { + return Poll::Ready(self.stop()); + } else if !self.flags.contains(Flags::SENDPAYLOAD_AND_STOP) { + if let Poll::Ready(Some(_)) = self.poll_request_payload(cx) { + self.flags.insert(Flags::SENDPAYLOAD_AND_STOP); + } + } + loop { + let _ = ready!(self.io.poll_flush(cx, false)); + let item = ready!(body.poll_next_chunk(cx)); + + let st = match item { + Some(Ok(item)) => { + log::trace!("{}: Got response chunk: {:?}", self.io.tag(), item.len()); + match self.io.encode(Message::Chunk(Some(item)), &self.codec) { + Ok(_) => continue, + Err(err) => self.ctl_proto_err(err.into()), } } + None => { + log::trace!("{}: Response payload eof {:?}", self.io.tag(), self.flags); + if let Err(err) = self.io.encode(Message::Chunk(None), &self.codec) { + self.ctl_proto_err(err.into()) + } else if self.flags.contains(Flags::DISCONNECT) { + self.stop() + } else if self.payload.is_some() { + State::ReadPayload + } else { + State::ReadRequest + } + } + Some(Err(err)) => { + log::trace!( + "{}: Error during response body poll: {:?}", + self.io.tag(), + err + ); + self.ctl_proto_err(ProtocolError::ResponsePayload(err)) + } + }; + return Poll::Ready(st); + } + } + + fn send_response_to( + &mut self, + res: Response<()>, + body: ResponseBody, + io: Option, + ) -> State { + let io = if let Some(io) = io { + io + } else if let Some((io, codec)) = res.head().io.take() { + self.codec = codec; + io + } else { + log::trace!("Handler service consumed io, stop"); + return self.stop(); + }; + + self.codec.set_ctype(ConnectionType::Close); + self.codec.unset_streaming(); + + if io + .encode(Message::Item((res, body.size())), &self.codec) + .is_ok() + { + match body.size() { + BodySize::None | BodySize::Empty => self.stop_io(io), + _ => State::SendPayloadAndStop { io, body }, } - None => { - trace!("{}: Response payload eof {:?}", self.io.tag(), self.flags); - if let Err(err) = self.io.encode(Message::Chunk(None), &self.codec) { - self.error = Some(DispatchError::Encode(err)); - Some(State::Stop) - } else if self.flags.contains(Flags::SENDPAYLOAD_AND_STOP) { - Some(State::Stop) - } else if self.payload.is_some() { - Some(State::ReadPayload) - } else { - Some(State::ReadRequest) + } else { + self.stop_io(io) + } + } + + /// send response body to specified io + fn poll_send_payload_to( + &mut self, + cx: &mut Context<'_>, + body: &mut ResponseBody, + io: &mut IoBoxed, + ) -> Poll> { + if io.is_closed() { + return Poll::Ready(self.stop()); + } else if !self.flags.contains(Flags::SENDPAYLOAD_AND_STOP) { + if let Poll::Ready(Err(_)) = self._poll_request_payload(Some(io), cx) { + self.flags.insert(Flags::SENDPAYLOAD_AND_STOP); + } + } + + loop { + let _ = ready!(io.poll_flush(cx, false)); + match ready!(body.poll_next_chunk(cx)) { + Some(Ok(item)) => { + if let Err(e) = io.encode(Message::Chunk(Some(item)), &self.codec) { + log::trace!("{}: Cannot encode chunk: {:?}", io.tag(), e); + } else { + continue; + } + } + None => { + if let Err(e) = io.encode(Message::Chunk(None), &self.codec) { + log::trace!("{}: Cannot encode payload eof: {:?}", io.tag(), e); + } + } + Some(Err(e)) => { + log::trace!("{}: error during response body poll: {:?}", io.tag(), e); } } - Some(Err(e)) => { - trace!( - "{}: Error during response body poll: {:?}", - self.io.tag(), - e - ); - self.error = Some(DispatchError::ResponsePayload(e)); - Some(State::Stop) + return Poll::Ready(self.stop_io(io.take())); + } + } + + /// we might need to read more data into a request payload + /// (ie service future can wait for payload data) + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.payload.is_some() { + if let Some(st) = ready!(self.poll_request_payload(cx)) { + return Poll::Ready(st); } + } else if self.poll_io_closed(cx) { + // check if io is closed + return Poll::Ready(self.stop()); + } + Poll::Pending + } + + fn set_payload_error(&mut self, err: PayloadError) { + if let Some(mut payload) = self.payload.take() { + payload.1.set_error(err); } } @@ -707,23 +526,23 @@ where fn poll_request_payload( &mut self, cx: &mut Context<'_>, - ) -> Poll> { - self._poll_request_payload::(None, cx) - } - - fn set_payload_error(&mut self, err: PayloadError) { - if let Some(ref mut payload) = self.payload { - payload.1.set_error(err); - self.payload = None; + ) -> Poll>> { + if let Err(err) = ready!(self._poll_request_payload::(None, cx)) { + Poll::Ready(Some(match err { + Either::Left(e) => self.ctl_proto_err(e), + Either::Right(e) => self.ctl_peer_gone(e), + })) + } else { + Poll::Ready(None) } } /// Process request's payload - fn _poll_request_payload( + fn _poll_request_payload( &mut self, - io: Option<&Io>, + io: Option<&Io>, cx: &mut Context<'_>, - ) -> Poll> { + ) -> Poll>>> { // check if payload data is required if self.payload.is_none() { return Poll::Ready(Ok(())); @@ -772,7 +591,10 @@ where .map(|io| io.poll_flush(cx, false)) .unwrap_or_else(|| self.io.poll_flush(cx, false)); - if flush_result?.is_pending() { + if flush_result + .map_err(|e| Either::Right(Some(e)))? + .is_pending() + { break; } else { continue; @@ -780,30 +602,25 @@ where } RecvError::KeepAlive => { if let Err(err) = self.handle_timeout() { - err + Either::Left(err) } else { continue; } } RecvError::Stop => { self.set_payload_error(PayloadError::EncodingCorrupted); - io::Error::new( + Either::Right(Some(io::Error::new( io::ErrorKind::Other, "Dispatcher stopped", - ) - .into() + ))) } RecvError::PeerGone(err) => { self.set_payload_error(PayloadError::EncodingCorrupted); - if let Some(err) = err { - DispatchError::PeerGone(Some(err)) - } else { - ParseError::Incomplete.into() - } + Either::Right(err) } RecvError::Decoder(e) => { self.set_payload_error(PayloadError::EncodingCorrupted); - DispatchError::Parse(e) + Either::Left(ProtocolError::Decode(e)) } }; return Poll::Ready(Err(err)); @@ -829,7 +646,7 @@ where // wait until future completes and then close // connection self.payload = None; - Poll::Ready(Err(DispatchError::PayloadIsNotConsumed)) + Poll::Ready(Err(Either::Left(ProtocolError::PayloadIsNotConsumed))) } } } @@ -847,7 +664,7 @@ where } } - fn handle_timeout(&mut self) -> Result<(), DispatchError> { + fn handle_timeout(&mut self) -> Result<(), ProtocolError> { // check read rate if self .flags @@ -901,19 +718,16 @@ where io::ErrorKind::TimedOut, "Keep-alive", ))); - Err(DispatchError::from(io::Error::new( - io::ErrorKind::TimedOut, - "Keep-alive", - ))) + Err(ProtocolError::SlowPayloadTimeout) } else { - Err(DispatchError::SlowRequestTimeout) + Err(ProtocolError::SlowRequestTimeout) } } fn update_hdrs_timer( &mut self, decoded: &Decoded<(Request, PayloadType)>, - ) -> Option> { + ) -> Option> { // got parsed frame if decoded.item.is_some() { self.read_remains = 0; @@ -940,7 +754,7 @@ where } } else { self.io.close(); - return Some(State::Stop); + return Some(self.stop()); } } else if let Some(ref cfg) = self.config.headers_read_rate { log::debug!( @@ -976,22 +790,65 @@ where self.io.start_timer(cfg.timeout); } } + + fn publish(&self, req: Request) -> State { + State::CallPublish { + fut: self.config.service.call_nowait(req), + } + } + + fn control(&self, req: Control) -> State { + State::CallControl { + fut: self.config.control.call_nowait(req), + } + } + + fn ctl_proto_err(&self, err: ProtocolError) -> State { + State::CallControl { + fut: self.config.control.call_nowait(Control::proto_err(err)), + } + } + + fn ctl_peer_gone(&self, err: Option) -> State { + State::CallControl { + fut: self.config.control.call_nowait(Control::peer_gone(err)), + } + } + + fn ctl_upgrade(&mut self, req: Request) -> State { + let msg = Control::upgrade(req, self.io.take(), self.codec.clone()); + self.control(msg) + } + + fn stop(&mut self) -> State { + State::Stop { + io: None, + fut: Some(self.config.control.call_nowait(Control::closed())), + } + } + + fn stop_io(&mut self, io: IoBoxed) -> State { + State::Stop { + io: Some(io), + fut: Some(self.config.control.call_nowait(Control::closed())), + } + } } #[cfg(test)] mod tests { use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; - use std::{cell::Cell, future::poll_fn, io, sync::Arc}; + use std::{cell::Cell, future::poll_fn, future::Future, io, sync::Arc}; use ntex_h2::Config; use rand::Rng; use super::*; use crate::http::config::{DispatcherConfig, ServiceConfig}; - use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler}; + use crate::http::h1::{ClientCodec, DefaultControlService}; use crate::http::{body, Request, ResponseHead, StatusCode}; use crate::io::{self as nio, Base}; - use crate::service::{boxed, fn_service, IntoService}; + use crate::service::{fn_service, IntoService}; use crate::util::{lazy, stream_recv, Bytes, BytesMut}; use crate::{codec::Decoder, testing::Io, time::sleep, time::Millis, time::Seconds}; @@ -1001,7 +858,7 @@ mod tests { pub(crate) fn h1( stream: Io, service: F, - ) -> Dispatcher> + ) -> Dispatcher where F: IntoService, S: Service, @@ -1021,9 +878,7 @@ mod tests { Rc::new(DispatcherConfig::new( config, service.into_service(), - ExpectHandler, - None, - None, + DefaultControlService, )), ) } @@ -1036,18 +891,14 @@ mod tests { S::Response: Into>, B: MessageBody + 'static, { - crate::rt::spawn( - Dispatcher::>::new( - nio::Io::new(stream), - Rc::new(DispatcherConfig::new( - ServiceConfig::default(), - service.into_service(), - ExpectHandler, - None, - None, - )), - ), - ); + crate::rt::spawn(Dispatcher::::new( + nio::Io::new(stream), + Rc::new(DispatcherConfig::new( + ServiceConfig::default(), + service.into_service(), + DefaultControlService, + )), + )); } fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead { @@ -1069,21 +920,19 @@ mod tests { Millis(5_000), Config::server(), ); - let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( + let mut h1 = Dispatcher::<_, _, _, _>::new( nio::Io::new(server), Rc::new(DispatcherConfig::new( config, fn_service(|_| { Box::pin(async { Ok::<_, io::Error>(Response::Ok().finish()) }) }), - ExpectHandler, - None, - Some(boxed::service(crate::service::into_service( - move |(req, _)| { + fn_service(move |req: Control<_, _>| { + if let Control::NewRequest(_) = req { data2.set(true); - Box::pin(async move { Ok(req) }) - }, - ))), + } + async move { Ok::<_, std::convert::Infallible>(req.ack()) } + }), )), ); sleep(Millis(50)).await; @@ -1330,7 +1179,7 @@ mod tests { fn poll_next_chunk( &mut self, _: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { let data = rand::thread_rng() .sample_iter(&rand::distributions::Alphanumeric) .take(65_536) @@ -1385,7 +1234,7 @@ mod tests { fn poll_next_chunk( &mut self, _: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { if self.0 { Poll::Pending } else { @@ -1478,14 +1327,12 @@ mod tests { Config::server(), ); config.payload_read_rate(Seconds(1), Seconds(2), 512); - let disp: Dispatcher> = Dispatcher::new( + let disp: Dispatcher = Dispatcher::new( nio::Io::new(server), Rc::new(DispatcherConfig::new( config, svc.into_service(), - ExpectHandler, - None, - None, + DefaultControlService, )), ); crate::rt::spawn(disp); diff --git a/ntex/src/http/h1/encoder.rs b/ntex/src/http/h1/encoder.rs index a4f1c0a2..ff56c7b3 100644 --- a/ntex/src/http/h1/encoder.rs +++ b/ntex/src/http/h1/encoder.rs @@ -1,13 +1,12 @@ use std::marker::PhantomData; -use std::{cell::Cell, cmp, io, io::Write, mem, ptr, ptr::copy_nonoverlapping, slice}; +use std::{cell::Cell, cmp, io::Write, mem, ptr, ptr::copy_nonoverlapping, slice}; use crate::http::body::BodySize; use crate::http::config::DateService; +use crate::http::error::EncodeError; use crate::http::header::{Value, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; -use crate::http::helpers; use crate::http::message::{ConnectionType, RequestHeadType}; -use crate::http::response::Response; -use crate::http::{HeaderMap, StatusCode, Version}; +use crate::http::{helpers, HeaderMap, Response, StatusCode, Version}; use crate::util::{BufMut, BytesMut}; const AVERAGE_HEADER_SIZE: usize = 30; @@ -48,7 +47,7 @@ pub(super) trait MessageType: Sized { fn chunked(&self) -> bool; - fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()>; + fn encode_status(&self, dst: &mut BytesMut) -> Result<(), EncodeError>; fn encode_headers( &self, @@ -57,7 +56,7 @@ pub(super) trait MessageType: Sized { mut length: BodySize, ctype: ConnectionType, timer: &DateService, - ) -> io::Result<()> { + ) -> Result<(), EncodeError> { let chunked = self.chunked(); let mut skip_len = length != BodySize::Stream; @@ -215,7 +214,7 @@ impl MessageType for Response<()> { None } - fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> { + fn encode_status(&self, dst: &mut BytesMut) -> Result<(), EncodeError> { let head = self.head(); let reason = head.reason().as_bytes(); dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len()); @@ -244,7 +243,7 @@ impl MessageType for RequestHeadType { self.extra_headers() } - fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> { + fn encode_status(&self, dst: &mut BytesMut) -> Result<(), EncodeError> { let head = self.as_ref(); dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE); write!( @@ -257,17 +256,20 @@ impl MessageType for RequestHeadType { Version::HTTP_09 => "HTTP/0.9", Version::HTTP_10 => "HTTP/1.0", Version::HTTP_11 => "HTTP/1.1", - _ => - return Err(io::Error::new(io::ErrorKind::Other, "unsupported version")), + _ => return Err(EncodeError::UnsupportedVersion(head.version)), } ) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + .map_err(|e| EncodeError::Internal(Box::new(e))) } } impl MessageEncoder { /// Encode message - pub(super) fn encode_chunk(&self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + pub(super) fn encode_chunk( + &self, + msg: &[u8], + buf: &mut BytesMut, + ) -> Result { let mut te = self.te.get(); let result = te.encode(msg, buf); self.te.set(te); @@ -275,7 +277,7 @@ impl MessageEncoder { } /// Encode eof - pub(super) fn encode_eof(&self, buf: &mut BytesMut) -> io::Result<()> { + pub(super) fn encode_eof(&self, buf: &mut BytesMut) -> Result<(), EncodeError> { let mut te = self.te.get(); let result = te.encode_eof(buf); self.te.set(te); @@ -292,7 +294,7 @@ impl MessageEncoder { length: BodySize, ctype: ConnectionType, timer: &DateService, - ) -> io::Result<()> { + ) -> Result<(), EncodeError> { // transfer encoding if !head { self.te.set(match length { @@ -367,7 +369,11 @@ impl TransferEncoding { /// Encode message. Return `EOF` state of encoder #[inline] - pub(super) fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { + pub(super) fn encode( + &mut self, + msg: &[u8], + buf: &mut BytesMut, + ) -> Result { match self.kind { TransferEncodingKind::Eof => { let eof = msg.is_empty(); @@ -385,7 +391,7 @@ impl TransferEncoding { true } else { writeln!(helpers::Writer(buf), "{:X}\r", msg.len()) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + .map_err(|e| EncodeError::Internal(Box::new(e)))?; buf.reserve(msg.len() + 2); buf.extend_from_slice(msg); @@ -415,12 +421,12 @@ impl TransferEncoding { /// Encode eof. Return `EOF` state of encoder #[inline] - pub(super) fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> { + pub(super) fn encode_eof(&mut self, buf: &mut BytesMut) -> Result<(), EncodeError> { match self.kind { TransferEncodingKind::Eof => Ok(()), TransferEncodingKind::Length(rem) => { if rem != 0 { - Err(io::Error::new(io::ErrorKind::UnexpectedEof, "")) + Err(EncodeError::UnexpectedEof) } else { Ok(()) } diff --git a/ntex/src/http/h1/expect.rs b/ntex/src/http/h1/expect.rs deleted file mode 100644 index b994bc02..00000000 --- a/ntex/src/http/h1/expect.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::io; - -use crate::http::request::Request; -use crate::service::{Service, ServiceCtx, ServiceFactory}; - -#[derive(Copy, Clone, Debug)] -pub struct ExpectHandler; - -impl ServiceFactory for ExpectHandler { - type Response = Request; - type Error = io::Error; - type Service = ExpectHandler; - type InitError = io::Error; - - async fn create(&self, _: ()) -> Result { - Ok(ExpectHandler) - } -} - -impl Service for ExpectHandler { - type Response = Request; - type Error = io::Error; - - async fn call( - &self, - req: Request, - _: ServiceCtx<'_, Self>, - ) -> Result { - Ok(req) - } -} diff --git a/ntex/src/http/h1/mod.rs b/ntex/src/http/h1/mod.rs index 4ebcc9b5..ccdf1c93 100644 --- a/ntex/src/http/h1/mod.rs +++ b/ntex/src/http/h1/mod.rs @@ -4,20 +4,21 @@ use crate::util::{Bytes, BytesMut}; mod client; mod codec; mod decoder; +mod default; mod dispatcher; mod encoder; -mod expect; mod payload; mod service; -mod upgrade; + +pub mod control; pub use self::client::{ClientCodec, ClientPayloadCodec}; pub use self::codec::Codec; +pub use self::control::{Control, ControlAck}; pub use self::decoder::{PayloadDecoder, PayloadItem, PayloadType}; -pub use self::expect::ExpectHandler; +pub use self::default::DefaultControlService; pub use self::payload::Payload; pub use self::service::{H1Service, H1ServiceHandler}; -pub use self::upgrade::UpgradeHandler; pub(super) use self::dispatcher::Dispatcher; @@ -55,3 +56,49 @@ pub(crate) fn reserve_readbuf(src: &mut BytesMut) { src.reserve(HW - cap); } } + +#[derive(thiserror::Error, Debug)] +/// A set of errors that can occur during dispatching http requests +pub enum ProtocolError { + /// Http request parse error. + #[error("Parse error: {0}")] + Decode(#[from] super::error::DecodeError), + + /// Http response encoding error. + #[error("Encode error: {0}")] + Encode(#[from] super::error::EncodeError), + + /// Request did not complete within the specified timeout + #[error("Request did not complete within the specified timeout")] + SlowRequestTimeout, + + /// Payload did not complete within the specified timeout + #[error("Payload did not complete within the specified timeout")] + SlowPayloadTimeout, + + /// Payload is not consumed + #[error("Task is completed but request's payload is not consumed")] + PayloadIsNotConsumed, + + /// Response body processing error + #[error("Response body processing error: {0}")] + ResponsePayload(Box), +} + +impl super::ResponseError for ProtocolError { + fn error_response(&self) -> super::Response { + match self { + ProtocolError::Decode(_) => super::Response::BadRequest().into(), + + ProtocolError::SlowRequestTimeout | ProtocolError::SlowPayloadTimeout => { + super::Response::RequestTimeout().into() + } + + ProtocolError::Encode(_) + | ProtocolError::PayloadIsNotConsumed + | ProtocolError::ResponsePayload(_) => { + super::Response::InternalServerError().into() + } + } + } +} diff --git a/ntex/src/http/h1/service.rs b/ntex/src/http/h1/service.rs index 1ce36aa4..edbba8c1 100644 --- a/ntex/src/http/h1/service.rs +++ b/ntex/src/http/h1/service.rs @@ -1,27 +1,25 @@ -use std::{cell::RefCell, error::Error, fmt, marker, rc::Rc, task}; +use std::{error::Error, fmt, marker, rc::Rc, task::Context, task::Poll}; use crate::http::body::MessageBody; -use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig}; +use crate::http::config::{DispatcherConfig, ServiceConfig}; use crate::http::error::{DispatchError, ResponseError}; use crate::http::{request::Request, response::Response}; use crate::io::{types, Filter, Io}; use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; -use super::codec::Codec; +use super::control::{Control, ControlAck}; +use super::default::DefaultControlService; use super::dispatcher::Dispatcher; -use super::{ExpectHandler, UpgradeHandler}; /// `ServiceFactory` implementation for HTTP1 transport -pub struct H1Service> { +pub struct H1Service { srv: S, + ctl: C, cfg: ServiceConfig, - expect: X, - upgrade: Option, - on_request: RefCell>, _t: marker::PhantomData<(F, B)>, } -impl H1Service +impl H1Service where S: ServiceFactory + 'static, S::Error: ResponseError, @@ -37,9 +35,7 @@ where H1Service { cfg, srv: service.into_factory(), - expect: ExpectHandler, - upgrade: None, - on_request: RefCell::new(None), + ctl: DefaultControlService, _t: marker::PhantomData, } } @@ -53,7 +49,7 @@ mod openssl { use super::*; use crate::{io::Layer, server::SslError}; - impl H1Service, S, B, X, U> + impl H1Service, S, B, C> where F: Filter, S: ServiceFactory + 'static, @@ -61,13 +57,10 @@ mod openssl { S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory + 'static, - X::Error: ResponseError, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io>, Codec), Response = ()> + C: ServiceFactory, S::Error>, Response = ControlAck> + 'static, - U::Error: fmt::Display + Error, - U::InitError: fmt::Debug, + C::Error: Error, + C::InitError: fmt::Debug, { /// Create openssl based service pub fn openssl( @@ -98,7 +91,7 @@ mod rustls { use super::*; use crate::{io::Layer, server::SslError}; - impl H1Service, S, B, X, U> + impl H1Service, S, B, C> where F: Filter, S: ServiceFactory + 'static, @@ -106,13 +99,12 @@ mod rustls { S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory + 'static, - X::Error: ResponseError, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io>, Codec), Response = ()> - + 'static, - U::Error: fmt::Display + Error, - U::InitError: fmt::Debug, + C: ServiceFactory< + Control, S::Error>, + Response = ControlAck, + > + 'static, + C::Error: Error, + C::InitError: fmt::Debug, { /// Create rustls based service pub fn rustls( @@ -133,57 +125,35 @@ mod rustls { } } -impl H1Service +impl H1Service where F: Filter, - S: ServiceFactory + 'static, + S: ServiceFactory, S::Error: ResponseError, S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, + C: ServiceFactory, Response = ControlAck>, + C::Error: Error, + C::InitError: fmt::Debug, { - pub fn expect(self, expect: X1) -> H1Service + /// Provide http/1 control service + pub fn control(self, ctl: C1) -> H1Service where - X1: ServiceFactory + 'static, - X1::Error: ResponseError, - X1::InitError: fmt::Debug, + C1: ServiceFactory, Response = ControlAck>, + C1::Error: Error, + C1::InitError: fmt::Debug, { H1Service { - expect, + ctl, cfg: self.cfg, srv: self.srv, - upgrade: self.upgrade, - on_request: self.on_request, _t: marker::PhantomData, } } - - pub fn upgrade(self, upgrade: Option) -> H1Service - where - U1: ServiceFactory<(Request, Io, Codec), Response = ()> + 'static, - U1::Error: fmt::Display + Error, - U1::InitError: fmt::Debug, - { - H1Service { - upgrade, - cfg: self.cfg, - srv: self.srv, - expect: self.expect, - on_request: self.on_request, - _t: marker::PhantomData, - } - } - - /// Set req request callback. - /// - /// It get called once per request. - pub(crate) fn on_request(self, f: Option) -> Self { - *self.on_request.borrow_mut() = f; - self - } } -impl ServiceFactory> for H1Service +impl ServiceFactory> for H1Service where F: Filter, S: ServiceFactory + 'static, @@ -191,43 +161,28 @@ where S::Response: Into>, S::InitError: fmt::Debug, B: MessageBody, - X: ServiceFactory + 'static, - X::Error: ResponseError, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io, Codec), Response = ()> + 'static, - U::Error: fmt::Display + Error, - U::InitError: fmt::Debug, + C: ServiceFactory, Response = ControlAck> + 'static, + C::Error: Error, + C::InitError: fmt::Debug, { type Response = (); type Error = DispatchError; type InitError = (); - type Service = H1ServiceHandler; + type Service = H1ServiceHandler; async fn create(&self, _: ()) -> Result { - let fut = self.srv.create(()); - let fut_ex = self.expect.create(()); - let fut_upg = self.upgrade.as_ref().map(|f| f.create(())); - let on_request = self.on_request.borrow_mut().take(); - let cfg = self.cfg.clone(); - - let service = fut + let service = self + .srv + .create(()) .await - .map_err(|e| log::error!("Init http service error: {:?}", e))?; - let expect = fut_ex + .map_err(|e| log::error!("Cannot construct publish service: {:?}", e))?; + let control = self + .ctl + .create(()) .await - .map_err(|e| log::error!("Init http service error: {:?}", e))?; - let upgrade = if let Some(fut) = fut_upg { - Some( - fut.await - .map_err(|e| log::error!("Init http service error: {:?}", e))?, - ) - } else { - None - }; + .map_err(|e| log::error!("Cannot construct control service: {:?}", e))?; - let config = Rc::new(DispatcherConfig::new( - cfg, service, expect, upgrade, on_request, - )); + let config = Rc::new(DispatcherConfig::new(self.cfg.clone(), service, control)); Ok(H1ServiceHandler { config, @@ -237,34 +192,38 @@ where } /// `Service` implementation for HTTP1 transport -pub struct H1ServiceHandler { - config: Rc>, +pub struct H1ServiceHandler { + config: Rc>, _t: marker::PhantomData<(F, B)>, } -impl Service> for H1ServiceHandler +impl Service> for H1ServiceHandler where F: Filter, + C: Service, Response = ControlAck> + 'static, + C::Error: Error, S: Service + 'static, S::Error: ResponseError, S::Response: Into>, B: MessageBody, - X: Service + 'static, - X::Error: ResponseError, - U: Service<(Request, Io, Codec), Response = ()> + 'static, - U::Error: fmt::Display + Error, { type Response = (); type Error = DispatchError; - fn poll_ready( - &self, - cx: &mut task::Context<'_>, - ) -> task::Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let cfg = self.config.as_ref(); - let ready = cfg - .expect + let ready1 = cfg + .control + .poll_ready(cx) + .map_err(|e| { + log::error!("Http control service readiness error: {:?}", e); + DispatchError::Control(Box::new(e)) + })? + .is_ready(); + + let ready2 = cfg + .service .poll_ready(cx) .map_err(|e| { log::error!("Http service readiness error: {:?}", e); @@ -272,57 +231,43 @@ where })? .is_ready(); - let ready = cfg - .service - .poll_ready(cx) - .map_err(|e| { - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(Box::new(e)) - })? - .is_ready() - && ready; - - let ready = if let Some(ref upg) = cfg.upgrade { - upg.poll_ready(cx) - .map_err(|e| { - log::error!("Http service readiness error: {:?}", e); - DispatchError::Upgrade(Box::new(e)) - })? - .is_ready() - && ready + if ready1 && ready2 { + Poll::Ready(Ok(())) } else { - ready - }; - - if ready { - task::Poll::Ready(Ok(())) - } else { - task::Poll::Pending + Poll::Pending } } - fn poll_shutdown(&self, cx: &mut task::Context<'_>) -> task::Poll<()> { - let ready = self.config.expect.poll_shutdown(cx).is_ready(); - let ready = self.config.service.poll_shutdown(cx).is_ready() && ready; - let ready = if let Some(ref upg) = self.config.upgrade { - upg.poll_shutdown(cx).is_ready() && ready - } else { - ready - }; + fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { + let ready1 = self.config.control.poll_shutdown(cx).is_ready(); + let ready2 = self.config.service.poll_shutdown(cx).is_ready(); - if ready { - task::Poll::Ready(()) + if ready1 && ready2 { + Poll::Ready(()) } else { - task::Poll::Pending + Poll::Pending } } - async fn call(&self, io: Io, _: ServiceCtx<'_, Self>) -> Result<(), DispatchError> { + async fn call(&self, io: Io, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { log::trace!( "New http1 connection, peer address {:?}", io.query::().get() ); - Dispatcher::new(io, self.config.clone()).await + let ack = self + .config + .control + .call_nowait(Control::con(io.get_ref())) + .await + .map_err(|e| DispatchError::Control(e.into()))?; + + if ack.flags.contains(super::control::ControlFlags::DISCONNECT) { + Ok(()) + } else { + Dispatcher::new(io, self.config.clone()) + .await + .map_err(DispatchError::Control) + } } } diff --git a/ntex/src/http/h1/upgrade.rs b/ntex/src/http/h1/upgrade.rs deleted file mode 100644 index a61c72eb..00000000 --- a/ntex/src/http/h1/upgrade.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::{io, marker::PhantomData}; - -use crate::http::{h1::Codec, request::Request}; -use crate::io::Io; -use crate::service::{Service, ServiceCtx, ServiceFactory}; - -pub struct UpgradeHandler(PhantomData); - -impl ServiceFactory<(Request, Io, Codec)> for UpgradeHandler { - type Response = (); - type Error = io::Error; - - type Service = UpgradeHandler; - type InitError = io::Error; - - async fn create(&self, _: ()) -> Result { - unimplemented!() - } -} - -impl Service<(Request, Io, Codec)> for UpgradeHandler { - type Response = (); - type Error = io::Error; - - async fn call( - &self, - _: (Request, Io, Codec), - _: ServiceCtx<'_, Self>, - ) -> Result { - unimplemented!() - } -} diff --git a/ntex/src/http/h2/default.rs b/ntex/src/http/h2/default.rs new file mode 100644 index 00000000..385e3f1a --- /dev/null +++ b/ntex/src/http/h2/default.rs @@ -0,0 +1,35 @@ +use std::io; + +use ntex_h2 as h2; + +use crate::http::error::H2Error; +use crate::service::{Service, ServiceCtx, ServiceFactory}; + +#[derive(Default)] +/// Default control service +pub struct DefaultControlService; + +impl ServiceFactory> for DefaultControlService { + type Response = h2::ControlResult; + type Error = io::Error; + type Service = DefaultControlService; + type InitError = io::Error; + + async fn create(&self, _: ()) -> Result { + Ok(DefaultControlService) + } +} + +impl Service> for DefaultControlService { + type Response = h2::ControlResult; + type Error = io::Error; + + async fn call( + &self, + msg: h2::ControlMessage, + _: ServiceCtx<'_, Self>, + ) -> Result { + log::trace!("HTTP/2 Control message: {:?}", msg); + Ok(msg.ack()) + } +} diff --git a/ntex/src/http/h2/mod.rs b/ntex/src/http/h2/mod.rs index df86d2bb..2d7b093c 100644 --- a/ntex/src/http/h2/mod.rs +++ b/ntex/src/http/h2/mod.rs @@ -1,7 +1,11 @@ //! HTTP/2 implementation +mod default; pub(super) mod payload; mod service; +pub use ntex_h2::{Config, ControlMessage, ControlResult}; + +pub use self::default::DefaultControlService; pub use self::payload::Payload; pub use self::service::H2Service; diff --git a/ntex/src/http/h2/service.rs b/ntex/src/http/h2/service.rs index c5acc715..3a208da5 100644 --- a/ntex/src/http/h2/service.rs +++ b/ntex/src/http/h2/service.rs @@ -1,5 +1,5 @@ use std::{cell::RefCell, io, task::Context, task::Poll}; -use std::{future::poll_fn, marker::PhantomData, mem, rc::Rc}; +use std::{error::Error, fmt, future::poll_fn, marker, mem, rc::Rc}; use ntex_h2::{self as h2, frame::StreamId, server}; @@ -14,15 +14,17 @@ use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use crate::util::{Bytes, BytesMut, HashMap}; use super::payload::{Payload, PayloadSender}; +use super::DefaultControlService; /// `ServiceFactory` implementation for HTTP2 transport -pub struct H2Service { +pub struct H2Service { srv: S, + ctl: Rc, cfg: ServiceConfig, - _t: PhantomData<(F, B)>, + _t: marker::PhantomData<(F, B)>, } -impl H2Service +impl H2Service where S: ServiceFactory, S::Error: ResponseError, @@ -37,7 +39,8 @@ where H2Service { cfg, srv: service.into_factory(), - _t: PhantomData, + ctl: Rc::new(DefaultControlService), + _t: marker::PhantomData, } } } @@ -51,13 +54,18 @@ mod openssl { use super::*; - impl H2Service, S, B> + impl H2Service, S, B, C> where F: Filter, S: ServiceFactory + 'static, S::Error: ResponseError, + S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, + C: ServiceFactory, Response = h2::ControlResult> + + 'static, + C::Error: Error, + C::InitError: fmt::Debug, { /// Create ssl based service pub fn openssl( @@ -67,7 +75,7 @@ mod openssl { Io, Response = (), Error = SslError, - InitError = S::InitError, + InitError = (), > { SslAcceptor::new(acceptor) .timeout(self.cfg.ssl_handshake_timeout) @@ -86,13 +94,18 @@ mod rustls { use super::*; use crate::{io::Layer, server::SslError}; - impl H2Service, S, B> + impl H2Service, S, B, C> where F: Filter, S: ServiceFactory + 'static, S::Error: ResponseError, + S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, + C: ServiceFactory, Response = h2::ControlResult> + + 'static, + C::Error: Error, + C::InitError: fmt::Debug, { /// Create openssl based service pub fn rustls( @@ -102,7 +115,7 @@ mod rustls { Io, Response = (), Error = SslError, - InitError = S::InitError, + InitError = (), > { let protos = vec!["h2".to_string().into()]; config.alpn_protocols = protos; @@ -116,49 +129,84 @@ mod rustls { } } -impl ServiceFactory> for H2Service +impl H2Service +where + F: Filter, + S: ServiceFactory + 'static, + S::Response: Into>, + S::Error: ResponseError, + S::InitError: fmt::Debug, + B: MessageBody, + C: ServiceFactory, Response = h2::ControlResult>, + C::Error: Error, + C::InitError: fmt::Debug, +{ + /// Provide http/2 control service + pub fn control(self, ctl: CT) -> H2Service + where + CT: ServiceFactory, Response = h2::ControlResult>, + CT::Error: Error, + CT::InitError: fmt::Debug, + { + H2Service { + ctl: Rc::new(ctl), + cfg: self.cfg, + srv: self.srv, + _t: marker::PhantomData, + } + } +} + +impl ServiceFactory> for H2Service where F: Filter, S: ServiceFactory + 'static, S::Error: ResponseError, + S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, + C: ServiceFactory, Response = h2::ControlResult> + 'static, + C::Error: Error, + C::InitError: fmt::Debug, { type Response = (); type Error = DispatchError; - type InitError = S::InitError; - type Service = H2ServiceHandler; + type InitError = (); + type Service = H2ServiceHandler; async fn create(&self, _: ()) -> Result { - let service = self.srv.create(()).await?; - let config = Rc::new(DispatcherConfig::new( - self.cfg.clone(), - service, - (), - None, - None, - )); + let service = self + .srv + .create(()) + .await + .map_err(|e| log::error!("Cannot construct publish service: {:?}", e))?; + let config = Rc::new(DispatcherConfig::new(self.cfg.clone(), service, ())); Ok(H2ServiceHandler { config, - _t: PhantomData, + control: self.ctl.clone(), + _t: marker::PhantomData, }) } } /// `Service` implementation for http/2 transport -pub struct H2ServiceHandler, B> { - config: Rc>, - _t: PhantomData<(F, B)>, +pub struct H2ServiceHandler, B, C> { + config: Rc>, + control: Rc, + _t: marker::PhantomData<(F, B)>, } -impl Service> for H2ServiceHandler +impl Service> for H2ServiceHandler where F: Filter, S: Service + 'static, S::Error: ResponseError, S::Response: Into>, B: MessageBody, + C: ServiceFactory, Response = h2::ControlResult> + 'static, + C::Error: Error, + C::InitError: fmt::Debug, { type Response = (); type Error = DispatchError; @@ -183,22 +231,28 @@ where "New http2 connection, peer address {:?}", io.query::().get() ); + let control = self.control.create(()).await.map_err(|e| { + DispatchError::Control( + format!("Cannot construct control service: {:?}", e).into(), + ) + })?; - handle(io.into(), self.config.clone()).await + handle(io.into(), control, self.config.clone()).await } } -pub(in crate::http) async fn handle( +pub(in crate::http) async fn handle( io: IoBoxed, - config: Rc>, + control: C2, + config: Rc>, ) -> Result<(), DispatchError> where S: Service + 'static, S::Error: ResponseError, S::Response: Into>, B: MessageBody, - X: 'static, - U: 'static, + C2: Service, Response = h2::ControlResult> + 'static, + C2::Error: Error, { io.set_disconnect_timeout(config.client_disconnect); let ioref = io.get_ref(); @@ -206,7 +260,7 @@ where let _ = server::handle_one( io, config.h2config.clone(), - ControlService::new(), + control, PublishService::new(ioref, config), ) .await; @@ -214,60 +268,37 @@ where Ok(()) } -struct ControlService {} - -impl ControlService { - fn new() -> Self { - Self {} - } -} - -impl Service> for ControlService { - type Response = h2::ControlResult; - type Error = (); - - async fn call( - &self, - msg: h2::ControlMessage, - _: ServiceCtx<'_, Self>, - ) -> Result { - log::trace!("Control message: {:?}", msg); - Ok::<_, ()>(msg.ack()) - } -} - -struct PublishService, B, X, U> { +struct PublishService, B, C> { io: IoRef, - config: Rc>, + config: Rc>, streams: RefCell>, - _t: PhantomData, + _t: marker::PhantomData, } -impl PublishService +impl PublishService where S: Service + 'static, S::Error: ResponseError, S::Response: Into>, B: MessageBody, { - fn new(io: IoRef, config: Rc>) -> Self { + fn new(io: IoRef, config: Rc>) -> Self { Self { io, config, streams: RefCell::new(HashMap::default()), - _t: PhantomData, + _t: marker::PhantomData, } } } -impl Service for PublishService +impl Service for PublishService where S: Service + 'static, S::Error: ResponseError, S::Response: Into>, B: MessageBody, - X: 'static, - U: 'static, + C: 'static, { type Response = (); type Error = H2Error; diff --git a/ntex/src/http/httpmessage.rs b/ntex/src/http/httpmessage.rs index ddbd30da..887dbb56 100644 --- a/ntex/src/http/httpmessage.rs +++ b/ntex/src/http/httpmessage.rs @@ -8,7 +8,7 @@ use ntex_http::header; #[cfg(feature = "cookie")] use coo_kie::Cookie; -use super::error::{ContentTypeError, ParseError}; +use super::error::{ContentTypeError, DecodeError}; use super::header::HeaderMap; use crate::util::Extensions; @@ -74,12 +74,12 @@ pub trait HttpMessage: Sized { } /// Check if request has chunked transfer encoding - fn chunked(&self) -> Result { + fn chunked(&self) -> Result { if let Some(encodings) = self.message_headers().get(header::TRANSFER_ENCODING) { if let Ok(s) = encodings.to_str() { Ok(s.to_lowercase().contains("chunked")) } else { - Err(ParseError::Header) + Err(DecodeError::Header) } } else { Ok(false) diff --git a/ntex/src/http/message.rs b/ntex/src/http/message.rs index 21e324f2..919b1cce 100644 --- a/ntex/src/http/message.rs +++ b/ntex/src/http/message.rs @@ -40,11 +40,15 @@ pub(crate) trait Head: Default + 'static { #[derive(Clone, Debug)] pub(crate) enum CurrentIo { Ref(IoRef), - Io(Rc<(IoRef, RefCell>>)>), + Io(Rc<(IoRef, RefCell>)>), None, } impl CurrentIo { + pub(crate) fn new(io: IoBoxed, codec: Codec) -> Self { + CurrentIo::Io(Rc::new((io.get_ref(), RefCell::new(Some((io, codec)))))) + } + pub(crate) fn as_ref(&self) -> Option<&IoRef> { match self { CurrentIo::Ref(ref io) => Some(io), @@ -52,6 +56,13 @@ impl CurrentIo { CurrentIo::None => None, } } + + pub(crate) fn take(&self) -> Option<(IoBoxed, Codec)> { + match self { + CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(), + _ => None, + } + } } #[derive(Debug)] @@ -197,11 +208,8 @@ impl RequestHead { /// Take io and codec for current request /// /// This objects are set only for upgrade requests - pub fn take_io(&self) -> Option> { - match self.io { - CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(), - _ => None, - } + pub fn take_io(&self) -> Option<(IoBoxed, Codec)> { + self.io.take() } } @@ -366,13 +374,6 @@ impl ResponseHead { pub(crate) fn set_io(&mut self, head: &RequestHead) { self.io = head.io.clone(); } - - pub(crate) fn take_io(&self) -> Option> { - match self.io { - CurrentIo::Io(ref inner) => inner.1.borrow_mut().take(), - _ => None, - } - } } impl Default for ResponseHead { diff --git a/ntex/src/http/request.rs b/ntex/src/http/request.rs index 6a4afefe..6cc32c60 100644 --- a/ntex/src/http/request.rs +++ b/ntex/src/http/request.rs @@ -171,7 +171,6 @@ impl Request { self.head.extensions_mut() } - #[allow(dead_code)] /// Split request into request head and payload pub(crate) fn into_parts(self) -> (Message, Payload) { (self.head, self.payload) diff --git a/ntex/src/http/service.rs b/ntex/src/http/service.rs index c5c61186..7af4315a 100644 --- a/ntex/src/http/service.rs +++ b/ntex/src/http/service.rs @@ -1,23 +1,28 @@ -use std::{cell, error, fmt, marker, rc::Rc, task::Context, task::Poll}; +use std::{error, fmt, marker, rc::Rc, task::Context, task::Poll}; use crate::io::{types, Filter, Io}; use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use super::body::MessageBody; use super::builder::HttpServiceBuilder; -use super::config::{DispatcherConfig, OnRequest, ServiceConfig}; -use super::error::{DispatchError, ResponseError}; +use super::config::{DispatcherConfig, ServiceConfig}; +use super::error::{DispatchError, H2Error, ResponseError}; use super::request::Request; use super::response::Response; use super::{h1, h2}; /// `ServiceFactory` HTTP1.1/HTTP2 transport implementation -pub struct HttpService> { +pub struct HttpService< + F, + S, + B, + C1 = h1::DefaultControlService, + C2 = h2::DefaultControlService, +> { srv: S, cfg: ServiceConfig, - expect: X, - upgrade: Option, - on_request: cell::RefCell>, + h1_control: C1, + h2_control: Rc, _t: marker::PhantomData<(F, B)>, } @@ -57,9 +62,8 @@ where HttpService { cfg, srv: service.into_factory(), - expect: h1::ExpectHandler, - upgrade: None, - on_request: cell::RefCell::new(None), + h1_control: h1::DefaultControlService, + h2_control: Rc::new(h2::DefaultControlService), _t: marker::PhantomData, } } @@ -72,15 +76,14 @@ where HttpService { cfg, srv: service.into_factory(), - expect: h1::ExpectHandler, - upgrade: None, - on_request: cell::RefCell::new(None), + h1_control: h1::DefaultControlService, + h2_control: Rc::new(h2::DefaultControlService), _t: marker::PhantomData, } } } -impl HttpService +impl HttpService where F: Filter, S: ServiceFactory + 'static, @@ -88,53 +91,44 @@ where S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, + C1: ServiceFactory, Response = h1::ControlAck>, + C1::Error: error::Error, + C1::InitError: fmt::Debug, + C2: ServiceFactory, Response = h2::ControlResult>, + C2::Error: error::Error, + C2::InitError: fmt::Debug, { - /// Provide service for `EXPECT: 100-Continue` support. - /// - /// Service get called with request that contains `EXPECT` header. - /// Service must return request in case of success, in that case - /// request will be forwarded to main service. - pub fn expect(self, expect: X1) -> HttpService + /// Provide http/1 control service. + pub fn h1_control(self, control: CT) -> HttpService where - X1: ServiceFactory, - X1::Error: ResponseError, - X1::InitError: fmt::Debug, + CT: ServiceFactory, Response = h1::ControlAck>, + CT::Error: error::Error, + CT::InitError: fmt::Debug, { HttpService { - expect, + h1_control: control, + h2_control: self.h2_control, cfg: self.cfg, srv: self.srv, - upgrade: self.upgrade, - on_request: self.on_request, _t: marker::PhantomData, } } - /// Provide service for custom `Connection: UPGRADE` support. - /// - /// If service is provided then normal requests handling get halted - /// and this service get called with original request and framed object. - pub fn upgrade(self, upgrade: Option) -> HttpService + /// Provide http/1 control service. + pub fn h2_control(self, control: CT) -> HttpService where - U1: ServiceFactory<(Request, Io, h1::Codec), Response = ()>, - U1::Error: fmt::Display + error::Error, - U1::InitError: fmt::Debug, + CT: ServiceFactory, Response = h2::ControlResult>, + CT::Error: error::Error, + CT::InitError: fmt::Debug, { HttpService { - upgrade, + h1_control: self.h1_control, + h2_control: Rc::new(control), cfg: self.cfg, srv: self.srv, - expect: self.expect, - on_request: self.on_request, _t: marker::PhantomData, } } - - /// Set on request callback. - pub(crate) fn on_request(self, f: Option) -> Self { - *self.on_request.borrow_mut() = f; - self - } } #[cfg(feature = "openssl")] @@ -145,7 +139,7 @@ mod openssl { use super::*; use crate::{io::Layer, server::SslError}; - impl HttpService, S, B, X, U> + impl HttpService, S, B, C1, C2> where F: Filter, S: ServiceFactory + 'static, @@ -153,13 +147,16 @@ mod openssl { S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory + 'static, - X::Error: ResponseError, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io>, h1::Codec), Response = ()> + C1: ServiceFactory< + h1::Control, S::Error>, + Response = h1::ControlAck, + > + 'static, + C1::Error: error::Error, + C1::InitError: fmt::Debug, + C2: ServiceFactory, Response = h2::ControlResult> + 'static, - U::Error: fmt::Display + error::Error, - U::InitError: fmt::Debug, + C2::Error: error::Error, + C2::InitError: fmt::Debug, { /// Create openssl based service pub fn openssl( @@ -188,7 +185,7 @@ mod rustls { use super::*; use crate::{io::Layer, server::SslError}; - impl HttpService, S, B, X, U> + impl HttpService, S, B, C1, C2> where F: Filter, S: ServiceFactory + 'static, @@ -196,15 +193,16 @@ mod rustls { S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory + 'static, - X::Error: ResponseError, - X::InitError: fmt::Debug, - U: ServiceFactory< - (Request, Io>, h1::Codec), - Response = (), + C1: ServiceFactory< + h1::Control, S::Error>, + Response = h1::ControlAck, > + 'static, - U::Error: fmt::Display + error::Error, - U::InitError: fmt::Debug, + C1::Error: error::Error, + C1::InitError: fmt::Debug, + C2: ServiceFactory, Response = h2::ControlResult> + + 'static, + C2::Error: error::Error, + C2::InitError: fmt::Debug, { /// Create openssl based service pub fn rustls( @@ -228,7 +226,7 @@ mod rustls { } } -impl ServiceFactory> for HttpService +impl ServiceFactory> for HttpService where F: Filter, S: ServiceFactory + 'static, @@ -236,68 +234,59 @@ where S::InitError: fmt::Debug, S::Response: Into>, B: MessageBody, - X: ServiceFactory + 'static, - X::Error: ResponseError, - X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io, h1::Codec), Response = ()> + 'static, - U::Error: fmt::Display + error::Error, - U::InitError: fmt::Debug, + C1: ServiceFactory, Response = h1::ControlAck> + 'static, + C1::Error: error::Error, + C1::InitError: fmt::Debug, + C2: ServiceFactory, Response = h2::ControlResult> + 'static, + C2::Error: error::Error, + C2::InitError: fmt::Debug, { type Response = (); type Error = DispatchError; type InitError = (); - type Service = HttpServiceHandler; + type Service = HttpServiceHandler; async fn create(&self, _: ()) -> Result { - let fut = self.srv.create(()); - let fut_ex = self.expect.create(()); - let fut_upg = self.upgrade.as_ref().map(|f| f.create(())); - let on_request = self.on_request.borrow_mut().take(); - let cfg = self.cfg.clone(); - - let service = fut + let service = self + .srv + .create(()) .await - .map_err(|e| log::error!("Init http service error: {:?}", e))?; - - let expect = fut_ex + .map_err(|e| log::error!("Cannot construct publish service: {:?}", e))?; + let control = self + .h1_control + .create(()) .await - .map_err(|e| log::error!("Init http service error: {:?}", e))?; + .map_err(|e| log::error!("Cannot construct control service: {:?}", e))?; - let upgrade = if let Some(fut) = fut_upg { - Some( - fut.await - .map_err(|e| log::error!("Init http service error: {:?}", e))?, - ) - } else { - None - }; - - let config = DispatcherConfig::new(cfg, service, expect, upgrade, on_request); + let config = DispatcherConfig::new(self.cfg.clone(), service, control); Ok(HttpServiceHandler { config: Rc::new(config), + h2_control: self.h2_control.clone(), _t: marker::PhantomData, }) } } /// `Service` implementation for http transport -pub struct HttpServiceHandler { - config: Rc>, +pub struct HttpServiceHandler { + config: Rc>, + h2_control: Rc, _t: marker::PhantomData<(F, B)>, } -impl Service> for HttpServiceHandler +impl Service> for HttpServiceHandler where F: Filter, S: Service + 'static, S::Error: ResponseError, S::Response: Into>, B: MessageBody, - X: Service + 'static, - X::Error: ResponseError, - U: Service<(Request, Io, h1::Codec), Response = ()> + 'static, - U::Error: fmt::Display + error::Error, + C1: Service, Response = h1::ControlAck> + 'static, + C1::Error: error::Error, + C2: ServiceFactory, Response = h2::ControlResult> + 'static, + C2::Error: error::Error, + C2::InitError: fmt::Debug, { type Response = (); type Error = DispatchError; @@ -305,8 +294,8 @@ where fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { let cfg = self.config.as_ref(); - let ready = cfg - .expect + let ready1 = cfg + .service .poll_ready(cx) .map_err(|e| { log::error!("Http service readiness error: {:?}", e); @@ -314,29 +303,16 @@ where })? .is_ready(); - let ready = cfg - .service + let ready2 = cfg + .control .poll_ready(cx) .map_err(|e| { - log::error!("Http service readiness error: {:?}", e); - DispatchError::Service(Box::new(e)) + log::error!("Http control service readiness error: {:?}", e); + DispatchError::Control(Box::new(e)) })? - .is_ready() - && ready; + .is_ready(); - let ready = if let Some(ref upg) = cfg.upgrade { - upg.poll_ready(cx) - .map_err(|e| { - log::error!("Http service readiness error: {:?}", e); - DispatchError::Upgrade(Box::new(e)) - })? - .is_ready() - && ready - } else { - ready - }; - - if ready { + if ready1 && ready2 { Poll::Ready(Ok(())) } else { Poll::Pending @@ -344,15 +320,10 @@ where } fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - let ready = self.config.expect.poll_shutdown(cx).is_ready(); - let ready = self.config.service.poll_shutdown(cx).is_ready() && ready; - let ready = if let Some(ref upg) = self.config.upgrade { - upg.poll_shutdown(cx).is_ready() && ready - } else { - ready - }; + let ready1 = self.config.control.poll_shutdown(cx).is_ready(); + let ready2 = self.config.service.poll_shutdown(cx).is_ready(); - if ready { + if ready1 && ready2 { Poll::Ready(()) } else { Poll::Pending @@ -370,9 +341,16 @@ where ); if io.query::().get() == Some(types::HttpProtocol::Http2) { - h2::handle(io.into(), self.config.clone()).await + let control = self.h2_control.create(()).await.map_err(|e| { + DispatchError::Control( + format!("Cannot construct control service: {:?}", e).into(), + ) + })?; + h2::handle(io.into(), control, self.config.clone()).await } else { - h1::Dispatcher::new(io, self.config.clone()).await + h1::Dispatcher::new(io, self.config.clone()) + .await + .map_err(DispatchError::Control) } } } diff --git a/ntex/src/web/app.rs b/ntex/src/web/app.rs index 49f11d6e..8f5b7830 100644 --- a/ntex/src/web/app.rs +++ b/ntex/src/web/app.rs @@ -352,12 +352,13 @@ where WebRequest, Response = WebRequest, Error = Err::Container, - InitError = (), >, U: IntoServiceFactory>, { App { - filter: self.filter.and_then(filter.into_factory()), + filter: self + .filter + .and_then(filter.into_factory().map_init_err(|_| ())), middleware: self.middleware, state_factories: self.state_factories, services: self.services, diff --git a/ntex/src/web/app_service.rs b/ntex/src/web/app_service.rs index c94a010a..dff9fdf8 100644 --- a/ntex/src/web/app_service.rs +++ b/ntex/src/web/app_service.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, marker::PhantomData, rc::Rc, task::Context, task::Poll}; +use std::{cell::RefCell, marker, rc::Rc, task::Context, task::Poll}; use crate::http::{Request, Response}; use crate::router::{Path, ResourceDef, Router}; @@ -89,11 +89,12 @@ where // update resource default service let default = self.default.clone().unwrap_or_else(|| { - Rc::new(boxed::factory(fn_service( - |req: WebRequest| async move { + Rc::new(boxed::factory( + fn_service(|req: WebRequest| async move { Ok(req.into_response(Response::NotFound().finish())) - }, - ))) + }) + .map_init_err(|_| ()), + )) }); let filter_fut = self.filter.create(()); @@ -109,7 +110,9 @@ where // app state factories for fut in state_factories.iter() { - extensions = fut(extensions).await?; + extensions = fut(extensions) + .await + .map_err(|_| log::error!("Cannot initialize state factory"))? } let state = AppState::new(extensions, None, config.clone()); @@ -143,19 +146,29 @@ where // create http services for (path, factory, guards) in &mut services.iter() { - let service = factory.create(()).await?; + let service = factory + .create(()) + .await + .map_err(|_| log::error!("Cannot construct app service"))?; router.rdef(path.clone(), service).2 = guards.borrow_mut().take(); } let routing = AppRouting { router: router.finish(), - default: Some(default.create(()).await?), + default: Some( + default + .create(()) + .await + .map_err(|_| log::error!("Cannot construct default service"))?, + ), }; // main service let service = AppService { routing, - filter: filter_fut.await?, + filter: filter_fut + .await + .map_err(|_| log::error!("Cannot construct app filter"))?, }; Ok(AppFactoryService { @@ -163,7 +176,7 @@ where state, service: middleware.create(service), pool: HttpRequestPool::create(), - _t: PhantomData, + _t: marker::PhantomData, }) } } @@ -178,7 +191,7 @@ where rmap: Rc, state: AppState, pool: &'static HttpRequestPool, - _t: PhantomData, + _t: marker::PhantomData, } impl Service for AppFactoryService diff --git a/ntex/src/web/error.rs b/ntex/src/web/error.rs index 1080a44d..f9012373 100644 --- a/ntex/src/web/error.rs +++ b/ntex/src/web/error.rs @@ -8,7 +8,6 @@ pub use serde_json::error::Error as JsonError; #[cfg(feature = "url")] pub use url_pkg::ParseError as UrlParseError; -use super::{HttpRequest, HttpResponse}; use crate::http::body::Body; use crate::http::helpers::Writer; use crate::http::{error, header, StatusCode}; @@ -17,6 +16,8 @@ use crate::util::{BytesMut, Either}; pub use super::error_default::{DefaultError, Error}; pub use crate::http::error::BlockingError; +use super::{HttpRequest, HttpResponse}; + pub trait ErrorRenderer: Sized + 'static { type Container: ErrorContainer; } diff --git a/ntex/src/web/mod.rs b/ntex/src/web/mod.rs index a2f891fc..f438133b 100644 --- a/ntex/src/web/mod.rs +++ b/ntex/src/web/mod.rs @@ -125,13 +125,14 @@ pub mod dev { //! The purpose of this module is to alleviate imports of many common //! traits by adding a glob import to the top of ntex::web heavy modules: - use super::Handler; pub use crate::web::config::AppConfig; pub use crate::web::info::ConnectionInfo; pub use crate::web::rmap::ResourceMap; pub use crate::web::route::IntoRoutes; pub use crate::web::service::{WebServiceAdapter, WebServiceConfig, WebServiceFactory}; + use crate::web::Handler; + pub(crate) fn insert_slash(mut patterns: Vec) -> Vec { for path in &mut patterns { if !path.is_empty() && !path.starts_with('/') { diff --git a/ntex/src/web/resource.rs b/ntex/src/web/resource.rs index 51b295fa..f6fec488 100644 --- a/ntex/src/web/resource.rs +++ b/ntex/src/web/resource.rs @@ -252,12 +252,13 @@ where WebRequest, Response = WebRequest, Error = Err::Container, - InitError = (), >, F: IntoServiceFactory>, { Resource { - filter: self.filter.and_then(filter.into_factory()), + filter: self + .filter + .and_then(filter.into_factory().map_init_err(|_| ())), middleware: self.middleware, rdef: self.rdef, name: self.name, diff --git a/ntex/src/web/scope.rs b/ntex/src/web/scope.rs index 6ae1a7e0..22d7b321 100644 --- a/ntex/src/web/scope.rs +++ b/ntex/src/web/scope.rs @@ -314,12 +314,13 @@ where WebRequest, Response = WebRequest, Error = Err::Container, - InitError = (), >, F: IntoServiceFactory>, { Scope { - filter: self.filter.and_then(filter.into_factory()), + filter: self + .filter + .and_then(filter.into_factory().map_init_err(|_| ())), middleware: self.middleware, rdef: self.rdef, state: self.state, diff --git a/ntex/src/web/service.rs b/ntex/src/web/service.rs index 6f34157d..76bd4501 100644 --- a/ntex/src/web/service.rs +++ b/ntex/src/web/service.rs @@ -246,16 +246,12 @@ impl WebServiceAdapter { pub fn finish(self, service: F) -> impl WebServiceFactory where F: IntoServiceFactory>, - T: ServiceFactory< - WebRequest, - Response = WebResponse, - Error = Err::Container, - InitError = (), - > + 'static, + T: ServiceFactory, Response = WebResponse, Error = Err::Container> + + 'static, Err: ErrorRenderer, { WebServiceImpl { - srv: service.into_factory(), + srv: service.into_factory().map_init_err(|_| ()), rdef: self.rdef, name: self.name, guards: self.guards, diff --git a/ntex/src/web/test.rs b/ntex/src/web/test.rs index 807971e2..ad59f1d3 100644 --- a/ntex/src/web/test.rs +++ b/ntex/src/web/test.rs @@ -73,7 +73,7 @@ pub async fn init_service( where R: IntoServiceFactory, S: ServiceFactory, - S::InitError: std::fmt::Debug, + S::InitError: fmt::Debug, { let srv = app.into_factory(); srv.pipeline(AppConfig::default()).await.unwrap() diff --git a/ntex/src/ws/client.rs b/ntex/src/ws/client.rs index db038580..672cc062 100644 --- a/ntex/src/ws/client.rs +++ b/ntex/src/ws/client.rs @@ -157,19 +157,20 @@ where log::trace!("Open ws connection to {:?} addr: {:?}", head.uri, self.addr); let io = self.connector.call(msg).await?; + let tag = io.tag(); // create Framed and send request let codec = h1::ClientCodec::default(); // send request and read response let fut = async { - log::trace!("Sending ws handshake http message"); + log::trace!("{}: Sending ws handshake http message", tag); io.send( (RequestHeadType::Rc(head, Some(headers)), BodySize::None).into(), &codec, ) .await?; - log::trace!("Waiting for ws handshake response"); + log::trace!("{}: Waiting for ws handshake response", tag); io.recv(&codec) .await? .ok_or(WsClientError::Disconnected(None)) @@ -184,7 +185,7 @@ where } else { fut.await? }; - log::trace!("Ws handshake response is received {:?}", response); + log::trace!("{}: Ws handshake response is received {:?}", tag, response); // verify response if response.status != StatusCode::SWITCHING_PROTOCOLS { @@ -202,7 +203,7 @@ where false }; if !has_hdr { - log::trace!("Invalid upgrade header"); + log::trace!("{}: Invalid upgrade header", tag); return Err(WsClientError::InvalidUpgradeHeader); } @@ -210,15 +211,15 @@ where if let Some(conn) = response.headers.get(&header::CONNECTION) { if let Ok(s) = conn.to_str() { if !s.to_ascii_lowercase().contains("upgrade") { - log::trace!("Invalid connection header: {}", s); + log::trace!("{}: Invalid connection header: {}", tag, s); return Err(WsClientError::InvalidConnectionHeader(conn.clone())); } } else { - log::trace!("Invalid connection header: {:?}", conn); + log::trace!("{}: Invalid connection header: {:?}", tag, conn); return Err(WsClientError::InvalidConnectionHeader(conn.clone())); } } else { - log::trace!("Missing connection header"); + log::trace!("{}: Missing connection header", tag); return Err(WsClientError::MissingConnectionHeader); } @@ -226,7 +227,8 @@ where let encoded = ws::hash_key(key.as_ref()); if hdr_key.as_bytes() != encoded.as_bytes() { log::trace!( - "Invalid challenge response: expected: {} received: {:?}", + "{}: Invalid challenge response: expected: {} received: {:?}", + tag, encoded, key ); @@ -236,10 +238,10 @@ where )); } } else { - log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + log::trace!("{}: Missing SEC-WEBSOCKET-ACCEPT header", tag); return Err(WsClientError::MissingWebSocketAcceptHeader); }; - log::trace!("Ws handshake response verification is completed"); + log::trace!("{}: Ws handshake response verification is completed", tag); // response and ws io Ok(WsConnection::new( @@ -295,7 +297,7 @@ impl WsClientBuilder { inner: Some(Inner { head, config, - connector: Connector::::default(), + connector: Connector::::default().tag("WS-CLIENT"), addr: None, max_size: 65_536, server_mode: false, @@ -563,7 +565,6 @@ where return Err(WsClientBuilderError::Http(e)); } - // #[allow(unused_mut)] let mut inner = self.inner.take().expect("cannot reuse WsClient builder"); // validate uri diff --git a/ntex/src/ws/error.rs b/ntex/src/ws/error.rs index c4859687..67ff46b1 100644 --- a/ntex/src/ws/error.rs +++ b/ntex/src/ws/error.rs @@ -3,7 +3,7 @@ use std::io; use thiserror::Error; -use crate::http::error::{HttpError, ParseError, ResponseError}; +use crate::http::error::{DecodeError, EncodeError, HttpError, ResponseError}; use crate::http::{header::HeaderValue, header::ALLOW, Response, StatusCode}; use crate::{connect::ConnectError, util::Either}; @@ -76,9 +76,12 @@ pub enum WsClientBuilderError { /// Websocket client error #[derive(Error, Debug)] pub enum WsClientError { + /// Invalid request + #[error("Invalid request")] + InvalidRequest(#[from] EncodeError), /// Invalid response #[error("Invalid response")] - InvalidResponse(#[from] ParseError), + InvalidResponse(#[from] DecodeError), /// Invalid response status #[error("Invalid response status")] InvalidResponseStatus(StatusCode), @@ -111,8 +114,8 @@ pub enum WsClientError { Disconnected(Option), } -impl From> for WsClientError { - fn from(err: Either) -> Self { +impl From> for WsClientError { + fn from(err: Either) -> Self { match err { Either::Left(err) => WsClientError::InvalidResponse(err), Either::Right(err) => WsClientError::Disconnected(Some(err)), @@ -120,9 +123,12 @@ impl From> for WsClientError { } } -impl From> for WsClientError { - fn from(err: Either) -> Self { - WsClientError::Disconnected(Some(err.into_inner())) +impl From> for WsClientError { + fn from(err: Either) -> Self { + match err { + Either::Left(err) => WsClientError::InvalidRequest(err), + Either::Right(err) => WsClientError::Disconnected(Some(err)), + } } } diff --git a/ntex/tests/connect.rs b/ntex/tests/connect.rs index 4da988dd..c7894159 100644 --- a/ntex/tests/connect.rs +++ b/ntex/tests/connect.rs @@ -84,19 +84,25 @@ async fn test_openssl_string() { let tcp = Some(tcp); let srv = build_test_server(move |srv| { srv.listen("test", tcp.unwrap(), |_| { - chain_factory(fn_service(|io: Io<_>| async move { - let res = io.read_ready().await; - assert!(res.is_ok()); - Ok(io) - })) + chain_factory( + fn_service(|io: Io<_>| async move { + let res = io.read_ready().await; + assert!(res.is_ok()); + Ok(io) + }) + .map_init_err(|_| ()), + ) .and_then(openssl::SslAcceptor::new(ssl_acceptor())) - .and_then(fn_service(|io: Io<_>| async move { - io.send(Bytes::from_static(b"test"), &BytesCodec) - .await - .unwrap(); - assert_eq!(io.recv(&BytesCodec).await.unwrap().unwrap(), "test"); - Ok::<_, Box>(()) - })) + .and_then( + fn_service(|io: Io<_>| async move { + io.send(Bytes::from_static(b"test"), &BytesCodec) + .await + .unwrap(); + assert_eq!(io.recv(&BytesCodec).await.unwrap().unwrap(), "test"); + Ok::<_, Box>(()) + }) + .map_init_err(|_| ()), + ) }) .unwrap() }) @@ -130,19 +136,25 @@ async fn test_openssl_read_before_error() { use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let srv = test_server(|| { - chain_factory(fn_service(|io: Io<_>| async move { - let res = io.read_ready().await; - assert!(res.is_ok()); - Ok(io) - })) + chain_factory( + fn_service(|io: Io<_>| async move { + let res = io.read_ready().await; + assert!(res.is_ok()); + Ok(io) + }) + .map_init_err(|_| ()), + ) .and_then(openssl::SslAcceptor::new(ssl_acceptor())) - .and_then(fn_service(|io: Io<_>| async move { - io.send(Bytes::from_static(b"test"), &Rc::new(BytesCodec)) - .await - .unwrap(); - time::sleep(time::Millis(100)).await; - Ok::<_, Box>(()) - })) + .and_then( + fn_service(|io: Io<_>| async move { + io.send(Bytes::from_static(b"test"), &Rc::new(BytesCodec)) + .await + .unwrap(); + time::sleep(time::Millis(100)).await; + Ok::<_, Box>(()) + }) + .map_init_err(|_| ()), + ) }); let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); @@ -171,21 +183,27 @@ async fn test_rustls_string() { use tls_rustls::{Certificate, ClientConfig}; let srv = test_server(|| { - chain_factory(fn_service(|io: Io<_>| async move { - let res = io.read_ready().await; - assert!(res.is_ok()); - Ok(io) - })) + chain_factory( + fn_service(|io: Io<_>| async move { + let res = io.read_ready().await; + assert!(res.is_ok()); + Ok(io) + }) + .map_init_err(|_| ()), + ) .and_then(rustls::TlsAcceptor::new(tls_acceptor())) - .and_then(fn_service(|io: Io<_>| async move { - assert!(io.query::().as_ref().is_none()); - assert!(io.query::().as_ref().is_none()); - io.send(Bytes::from_static(b"test"), &BytesCodec) - .await - .unwrap(); - assert_eq!(io.recv(&BytesCodec).await.unwrap().unwrap(), "test"); - Ok::<_, std::io::Error>(()) - })) + .and_then( + fn_service(|io: Io<_>| async move { + assert!(io.query::().as_ref().is_none()); + assert!(io.query::().as_ref().is_none()); + io.send(Bytes::from_static(b"test"), &BytesCodec) + .await + .unwrap(); + assert_eq!(io.recv(&BytesCodec).await.unwrap().unwrap(), "test"); + Ok::<_, std::io::Error>(()) + }) + .map_init_err(|_| ()), + ) }); let config = ClientConfig::builder() diff --git a/ntex/tests/http_awc_openssl_client.rs b/ntex/tests/http_awc_openssl_client.rs index 0b8cb037..7172d1db 100644 --- a/ntex/tests/http_awc_openssl_client.rs +++ b/ntex/tests/http_awc_openssl_client.rs @@ -8,7 +8,7 @@ use tls_openssl::ssl::{ use ntex::http::client::{Client, Connector}; use ntex::http::test::server as test_server; use ntex::http::{HttpService, Version}; -use ntex::service::{chain_factory, map_config, ServiceFactory}; +use ntex::service::{chain_factory, map_config}; use ntex::web::{self, dev::AppConfig, App, HttpResponse}; use ntex::{time::Seconds, util::Ready}; @@ -52,8 +52,7 @@ async fn test_connection_reuse_h2() { ), |_| AppConfig::default(), )) - .openssl(ssl_acceptor()) - .map_err(|_| ()), + .openssl(ssl_acceptor()), //.map_err(|_| ()), ) }); diff --git a/ntex/tests/http_openssl.rs b/ntex/tests/http_openssl.rs index 4507a361..dd45fd7a 100644 --- a/ntex/tests/http_openssl.rs +++ b/ntex/tests/http_openssl.rs @@ -12,7 +12,7 @@ use ntex::http::{body, h1, HttpService, Method, Request, Response, StatusCode, V use ntex::service::{fn_service, ServiceFactory}; use ntex::time::{sleep, timeout, Millis, Seconds}; use ntex::util::{Bytes, BytesMut, Ready}; -use ntex::{io::Io, web::error::InternalError, ws, ws::handshake_response}; +use ntex::{web::error::InternalError, ws, ws::handshake_response}; async fn load_body(stream: S) -> Result where @@ -483,26 +483,32 @@ async fn test_ssl_handshake_timeout() { async fn test_ws_transport() { let mut srv = test_server(|| { HttpService::build() - .upgrade(|(req, io, codec): (Request, Io<_>, h1::Codec)| { - async move { - let res = handshake_response(req.head()).finish(); + .h1_control(|req: h1::Control<_, _>| async move { + let ack = if let h1::Control::Upgrade(upg) = req { + upg.handle(|req, io, codec| async move { + let res = handshake_response(req.head()).finish(); - // send handshake respone - io.encode( - h1::Message::Item((res.drop_body(), body::BodySize::None)), - &codec, - ) - .unwrap(); + // send handshake respone + io.encode( + h1::Message::Item((res.drop_body(), body::BodySize::None)), + &codec, + ) + .unwrap(); - // start websocket service - let io = ws::WsTransport::create(io, ws::Codec::default()); - while let Some(item) = - io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? - { - io.send(item.freeze(), &BytesCodec).await.unwrap() - } - Ok::<_, io::Error>(()) - } + // start websocket service + let io = ws::WsTransport::create(io, ws::Codec::default()); + while let Some(item) = + io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? + { + io.send(item.freeze(), &BytesCodec).await.unwrap() + } + + Ok::<_, io::Error>(()) + }) + } else { + req.ack() + }; + Ok::<_, io::Error>(ack) }) .finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound())) .openssl(ssl_acceptor()) diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 7fdcb9cb..abaad8a7 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -5,13 +5,15 @@ use futures_util::future::{self, FutureExt}; use futures_util::stream::{once, StreamExt}; use regex::Regex; +use ntex::http::h1::Control; use ntex::http::header::{self, HeaderName, HeaderValue}; use ntex::http::test::server as test_server; use ntex::http::{ body, HttpService, KeepAlive, Method, Request, Response, StatusCode, Version, }; +use ntex::service::fn_service; use ntex::time::{sleep, timeout, Millis, Seconds}; -use ntex::{service::fn_service, util::Bytes, util::Ready, web::error}; +use ntex::{util::Bytes, util::Ready, web::error}; #[ntex::test] async fn test_h1() { @@ -56,16 +58,21 @@ async fn test_h1_2() { async fn test_expect_continue() { let srv = test_server(|| { HttpService::build() - .expect(fn_service(|req: Request| async move { + .h1_control(fn_service(|req: Control<_, _>| async move { sleep(Millis(20)).await; - if req.head().uri.query() == Some("yes=") { - Ok(req) + let ack = if let Control::Expect(exc) = req { + if exc.get_ref().head().uri.query() == Some("yes=") { + exc.ack() + } else { + exc.fail(error::InternalError::default( + "error", + StatusCode::PRECONDITION_FAILED, + )) + } } else { - Err(error::InternalError::default( - "error", - StatusCode::PRECONDITION_FAILED, - )) - } + req.ack() + }; + Ok::<_, std::convert::Infallible>(ack) })) .keep_alive(KeepAlive::Disabled) .h1(fn_service(|mut req: Request| async move { diff --git a/ntex/tests/http_ws.rs b/ntex/tests/http_ws.rs index 26c4eeec..ce3b1bee 100644 --- a/ntex/tests/http_ws.rs +++ b/ntex/tests/http_ws.rs @@ -4,7 +4,7 @@ use ntex::codec::BytesCodec; use ntex::http::test::server as test_server; use ntex::http::{body, h1, test, HttpService, Request, Response, StatusCode}; use ntex::io::{DispatchItem, Dispatcher, Io}; -use ntex::service::{fn_factory, Service, ServiceCtx}; +use ntex::service::{Pipeline, Service, ServiceCtx}; use ntex::time::Seconds; use ntex::util::{ByteString, Bytes, Ready}; use ntex::ws::{self, handshake, handshake_response}; @@ -82,14 +82,22 @@ async fn test_simple() { let mut srv = test::server({ let ws_service = ws_service.clone(); move || { - let ws_service = ws_service.clone(); + let ws_service = Pipeline::new(ws_service.clone()); HttpService::build() .keep_alive(1) .headers_read_rate(Seconds(1), Seconds::ZERO, 16) .payload_read_rate(Seconds(1), Seconds::ZERO, 16) - .upgrade(fn_factory(move || { - Ready::Ok::<_, io::Error>(ws_service.clone()) - })) + .h1_control(move |req: h1::Control<_, _>| { + let ack = if let h1::Control::Upgrade(upg) = req { + let ws_service = ws_service.clone(); + upg.handle(|req, io, codec| async move { + ws_service.call((req, io, codec)).await + }) + } else { + req.ack() + }; + async move { Ok::<_, io::Error>(ack) } + }) .h1(|_| Ready::Ok::<_, io::Error>(Response::NotFound())) } }); @@ -249,27 +257,32 @@ async fn test_simple() { async fn test_transport() { let mut srv = test_server(|| { HttpService::build() - .upgrade(|(req, io, codec): (Request, Io, h1::Codec)| { - async move { - let res = handshake_response(req.head()).finish(); + .h1_control(move |req: h1::Control<_, _>| { + let ack = if let h1::Control::Upgrade(upg) = req { + upg.handle(|req, io, codec| async move { + let res = handshake_response(req.head()).finish(); - // send handshake respone - io.encode( - h1::Message::Item((res.drop_body(), body::BodySize::None)), - &codec, - ) - .unwrap(); + // send handshake respone + io.encode( + h1::Message::Item((res.drop_body(), body::BodySize::None)), + &codec, + ) + .unwrap(); - let io = ws::WsTransport::create(io, ws::Codec::default()); + let io = ws::WsTransport::create(io, ws::Codec::default()); - // start websocket service - while let Some(item) = - io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? - { - io.send(item.freeze(), &BytesCodec).await.unwrap() - } - Ok::<_, io::Error>(()) - } + // start websocket service + while let Some(item) = + io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? + { + io.send(item.freeze(), &BytesCodec).await.unwrap() + } + Ok::<_, io::Error>(()) + }) + } else { + req.ack() + }; + async move { Ok::<_, io::Error>(ack) } }) .finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound())) }); diff --git a/ntex/tests/http_ws_client.rs b/ntex/tests/http_ws_client.rs index e3719902..ba3aa2fa 100644 --- a/ntex/tests/http_ws_client.rs +++ b/ntex/tests/http_ws_client.rs @@ -2,8 +2,8 @@ use std::io; use ntex::codec::BytesCodec; use ntex::http::test::server as test_server; -use ntex::http::{body::BodySize, h1, HttpService, Request, Response}; -use ntex::io::{DispatchItem, Dispatcher, DispatcherConfig, Io}; +use ntex::http::{body::BodySize, h1, HttpService, Response}; +use ntex::io::{DispatchItem, Dispatcher, DispatcherConfig}; use ntex::service::{fn_factory_with_config, fn_service}; use ntex::web::{self, App, HttpRequest}; use ntex::ws::{self, handshake_response}; @@ -31,23 +31,31 @@ async fn ws_service( async fn test_simple() { let mut srv = test_server(|| { HttpService::build() - .upgrade(|(req, io, codec): (Request, Io, h1::Codec)| { - async move { - let res = handshake_response(req.head()).finish(); + .h1_control(|req: h1::Control<_, _>| async move { + let ack = if let h1::Control::Upgrade(upg) = req { + upg.handle(|req, io, codec| async move { + let res = handshake_response(req.head()).finish(); - // send handshake respone - io.encode(h1::Message::Item((res.drop_body(), BodySize::None)), &codec) + // send handshake respone + io.encode( + h1::Message::Item((res.drop_body(), BodySize::None)), + &codec, + ) .unwrap(); - // start websocket service - Dispatcher::new( - io.seal(), - ws::Codec::default(), - ws_service, - &Default::default(), - ) - .await - } + // start websocket service + Dispatcher::new( + io.seal(), + ws::Codec::default(), + ws_service, + &Default::default(), + ) + .await + }) + } else { + req.ack() + }; + Ok::<_, io::Error>(ack) }) .finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound())) }); @@ -87,23 +95,31 @@ async fn test_simple() { async fn test_transport() { let mut srv = test_server(|| { HttpService::build() - .upgrade(|(req, io, codec): (Request, Io, h1::Codec)| { - async move { - let res = handshake_response(req.head()).finish(); + .h1_control(|req: h1::Control<_, _>| async move { + let ack = if let h1::Control::Upgrade(upg) = req { + upg.handle(|req, io, codec| async move { + let res = handshake_response(req.head()).finish(); - // send handshake respone - io.encode(h1::Message::Item((res.drop_body(), BodySize::None)), &codec) + // send handshake respone + io.encode( + h1::Message::Item((res.drop_body(), BodySize::None)), + &codec, + ) .unwrap(); - // start websocket service - Dispatcher::new( - io.seal(), - ws::Codec::default(), - ws_service, - &Default::default(), - ) - .await - } + // start websocket service + Dispatcher::new( + io.seal(), + ws::Codec::default(), + ws_service, + &Default::default(), + ) + .await + }) + } else { + req.ack() + }; + Ok::<_, io::Error>(ack) }) .finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound())) }); @@ -122,19 +138,28 @@ async fn test_transport() { async fn test_keepalive_timeout() { let srv = test_server(|| { HttpService::build() - .upgrade(|(req, io, codec): (Request, Io, h1::Codec)| { - async move { - let res = handshake_response(req.head()).finish(); + .h1_control(|req: h1::Control<_, _>| async move { + let ack = if let h1::Control::Upgrade(upg) = req { + upg.handle(|req, io, codec| async move { + let res = handshake_response(req.head()).finish(); - // send handshake respone - io.encode(h1::Message::Item((res.drop_body(), BodySize::None)), &codec) + // send handshake respone + io.encode( + h1::Message::Item((res.drop_body(), BodySize::None)), + &codec, + ) .unwrap(); - // start websocket service - let cfg = DispatcherConfig::default(); - cfg.set_keepalive_timeout(Seconds::ZERO); - Dispatcher::new(io.seal(), ws::Codec::default(), ws_service, &cfg).await - } + // start websocket service + let cfg = DispatcherConfig::default(); + cfg.set_keepalive_timeout(Seconds::ZERO); + Dispatcher::new(io.seal(), ws::Codec::default(), ws_service, &cfg) + .await + }) + } else { + req.ack() + }; + Ok::<_, io::Error>(ack) }) .finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound())) });