diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index ec32ab6c..f9f7d5ad 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [0.5.0-b.3] - 2021-12-xx + +* Remove websocket support from http::client + +* Add standalone ws::client + ## [0.5.0-b.2] - 2021-12-22 * Refactor write back-pressure for http1 diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 41f57242..bffaa4fe 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.5.0-b.2" +version = "0.5.0-b.3" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" diff --git a/ntex/src/connect/rustls.rs b/ntex/src/connect/rustls.rs index bd23d9f7..3ff764bf 100644 --- a/ntex/src/connect/rustls.rs +++ b/ntex/src/connect/rustls.rs @@ -17,6 +17,15 @@ pub struct Connector { inner: TlsConnector, } +impl From> for Connector { + fn from(cfg: std::sync::Arc) -> Self { + Connector { + inner: TlsConnector::new(cfg), + connector: BaseConnector::default(), + } + } +} + impl Connector { pub fn new(config: ClientConfig) -> Self { Connector { diff --git a/ntex/src/http/client/connect.rs b/ntex/src/http/client/connect.rs index 5643dba2..6781fec8 100644 --- a/ntex/src/http/client/connect.rs +++ b/ntex/src/http/client/connect.rs @@ -1,23 +1,13 @@ use std::{future::Future, net, pin::Pin}; use crate::http::body::Body; -use crate::http::h1::ClientCodec; -use crate::http::{RequestHeadType, ResponseHead}; -use crate::io::IoBoxed; -use crate::Service; +use crate::http::RequestHeadType; +use crate::service::Service; use super::error::{ConnectError, SendRequestError}; use super::response::ClientResponse; use super::{Connect as ClientConnect, Connection}; -pub(crate) type TunnelFuture = Pin< - Box< - dyn Future< - Output = Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError>, - >, - >, ->; - pub(super) struct ConnectorWrapper(pub(crate) T); pub(super) trait Connect { @@ -27,13 +17,6 @@ pub(super) trait Connect { body: Body, addr: Option, ) -> Pin>>>; - - /// Send request, returns Response and Framed - fn open_tunnel( - &self, - head: RequestHeadType, - addr: Option, - ) -> TunnelFuture; } impl Connect for ConnectorWrapper @@ -63,23 +46,4 @@ where .map(|(head, payload)| ClientResponse::new(head, payload)) }) } - - fn open_tunnel( - &self, - head: RequestHeadType, - addr: Option, - ) -> TunnelFuture { - // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.as_ref().uri.clone(), - addr, - }); - - Box::pin(async move { - let connection = fut.await?; - - // send request - connection.open_tunnel(head).await - }) - } } diff --git a/ntex/src/http/client/connection.rs b/ntex/src/http/client/connection.rs index 23f93d34..be22fce6 100644 --- a/ntex/src/http/client/connection.rs +++ b/ntex/src/http/client/connection.rs @@ -4,7 +4,6 @@ use h2::client::SendRequest; use ntex_tls::types::HttpProtocol; use crate::http::body::MessageBody; -use crate::http::h1::ClientCodec; use crate::http::message::{RequestHeadType, ResponseHead}; use crate::http::payload::Payload; use crate::io::IoBoxed; @@ -91,24 +90,4 @@ impl Connection { } } } - - /// Send request, returns Response and Framed - pub(super) async fn open_tunnel>( - mut self, - head: H, - ) -> Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError> { - match self.io.take().unwrap() { - ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await, - ConnectionType::H2(io) => { - if let Some(mut pool) = self.pool.take() { - pool.release(Connection::new( - ConnectionType::H2(io), - self.created, - None, - )); - } - Err(SendRequestError::TunnelNotSupported) - } - } - } } diff --git a/ntex/src/http/client/connector.rs b/ntex/src/http/client/connector.rs index 6ddeb7bb..1b4b5e18 100644 --- a/ntex/src/http/client/connector.rs +++ b/ntex/src/http/client/connector.rs @@ -55,7 +55,7 @@ impl Connector { let conn = Connector { connector: boxed::service( TcpConnector::new() - .map(|io| io.into_boxed()) + .map(|io| io.seal()) .map_err(ConnectError::from), ), ssl_connector: None, @@ -184,11 +184,8 @@ impl Connector { > + 'static, F: Filter, { - self.connector = boxed::service( - connector - .map(|io| io.into_boxed()) - .map_err(ConnectError::from), - ); + self.connector = + boxed::service(connector.map(|io| io.seal()).map_err(ConnectError::from)); self } @@ -203,9 +200,7 @@ impl Connector { F: Filter, { self.ssl_connector = Some(boxed::service( - connector - .map(|io| io.into_boxed()) - .map_err(ConnectError::from), + connector.map(|io| io.seal()).map_err(ConnectError::from), )); self } diff --git a/ntex/src/http/client/error.rs b/ntex/src/http/client/error.rs index c6ba2f23..126b8321 100644 --- a/ntex/src/http/client/error.rs +++ b/ntex/src/http/client/error.rs @@ -8,53 +8,7 @@ use serde_json::error::Error as JsonError; use crate::connect::openssl::{HandshakeError, SslError}; use crate::http::error::{HttpError, ParseError, PayloadError}; -use crate::http::header::HeaderValue; -use crate::http::StatusCode; use crate::util::Either; -use crate::ws::ProtocolError; - -/// Websocket client error -#[derive(Debug, Display, From)] -pub enum WsClientError { - /// Invalid response status - #[display(fmt = "Invalid response status")] - InvalidResponseStatus(StatusCode), - /// Invalid upgrade header - #[display(fmt = "Invalid upgrade header")] - InvalidUpgradeHeader, - /// Invalid connection header - #[display(fmt = "Invalid connection header")] - InvalidConnectionHeader(HeaderValue), - /// Missing CONNECTION header - #[display(fmt = "Missing CONNECTION header")] - MissingConnectionHeader, - /// Missing SEC-WEBSOCKET-ACCEPT header - #[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")] - MissingWebSocketAcceptHeader, - /// Invalid challenge response - #[display(fmt = "Invalid challenge response")] - InvalidChallengeResponse(String, HeaderValue), - /// Protocol error - #[display(fmt = "{}", _0)] - Protocol(ProtocolError), - /// Send request error - #[display(fmt = "{}", _0)] - SendRequest(SendRequestError), -} - -impl std::error::Error for WsClientError {} - -impl From for WsClientError { - fn from(err: InvalidUrl) -> Self { - WsClientError::SendRequest(err.into()) - } -} - -impl From for WsClientError { - fn from(err: HttpError) -> Self { - WsClientError::SendRequest(err.into()) - } -} /// A set of errors that can occur during parsing json payloads #[derive(Debug, Display, From)] @@ -107,16 +61,12 @@ pub enum ConnectError { Timeout, /// Connector has been disconnected - #[display(fmt = "Internal error: connector has been disconnected")] - Disconnected, + #[display(fmt = "Connector has been disconnected")] + Disconnected(Option), /// Unresolved host name #[display(fmt = "Connector received `Connect` method with unresolved host")] Unresolved, - - /// Connection io error - #[display(fmt = "{}", _0)] - Io(io::Error), } impl std::error::Error for ConnectError {} @@ -128,7 +78,7 @@ impl From for ConnectError { crate::connect::ConnectError::NoRecords => ConnectError::NoRecords, crate::connect::ConnectError::InvalidInput => panic!(), crate::connect::ConnectError::Unresolved => ConnectError::Unresolved, - crate::connect::ConnectError::Io(e) => ConnectError::Io(e), + crate::connect::ConnectError::Io(e) => ConnectError::Disconnected(Some(e)), } } } diff --git a/ntex/src/http/client/h1proto.rs b/ntex/src/http/client/h1proto.rs index 3e4e9d2d..ed3c3198 100644 --- a/ntex/src/http/client/h1proto.rs +++ b/ntex/src/http/client/h1proto.rs @@ -82,7 +82,7 @@ where ); result } else { - return Err(SendRequestError::from(ConnectError::Disconnected)); + return Err(SendRequestError::from(ConnectError::Disconnected(None))); }; match codec.message_type() { @@ -98,22 +98,6 @@ where } } -pub(super) async fn open_tunnel( - io: IoBoxed, - head: RequestHeadType, -) -> Result<(ResponseHead, IoBoxed, h1::ClientCodec), SendRequestError> { - // create Framed and send request - let codec = h1::ClientCodec::default(); - io.send(&codec, (head, BodySize::None).into()).await?; - - // read response - if let Some(head) = io.recv(&codec).await? { - Ok((head, io, codec)) - } else { - Err(SendRequestError::from(ConnectError::Disconnected)) - } -} - /// send request body to the peer pub(super) async fn send_body( mut body: B, diff --git a/ntex/src/http/client/mod.rs b/ntex/src/http/client/mod.rs index 50c8c9b2..a72e4365 100644 --- a/ntex/src/http/client/mod.rs +++ b/ntex/src/http/client/mod.rs @@ -31,7 +31,6 @@ mod request; mod response; mod sender; mod test; -pub mod ws; pub use self::builder::ClientBuilder; pub use self::connection::Connection; @@ -193,17 +192,4 @@ impl Client { { self.request(Method::OPTIONS, url) } - - /// Construct WebSockets request. - pub fn ws(&self, url: U) -> ws::WsRequest - where - Uri: TryFrom, - >::Error: Into, - { - let mut req = ws::WsRequest::new(url, self.0.clone()); - for (key, value) in self.0.headers.iter() { - req.head.headers.insert(key.clone(), value.clone()); - } - req - } } diff --git a/ntex/src/http/client/pool.rs b/ntex/src/http/client/pool.rs index cdd87366..290bb9c4 100644 --- a/ntex/src/http/client/pool.rs +++ b/ntex/src/http/client/pool.rs @@ -136,7 +136,7 @@ where OpenConnection::spawn(key, tx, inner, connector.call(req)); match rx.await { - Err(_) => Err(ConnectError::Disconnected), + Err(_) => Err(ConnectError::Disconnected(None)), Ok(res) => res, } } @@ -148,7 +148,7 @@ where ); let rx = inner.borrow_mut().wait_for(req); match rx.await { - Err(_) => Err(ConnectError::Disconnected), + Err(_) => Err(ConnectError::Disconnected(None)), Ok(res) => res, } } @@ -536,7 +536,7 @@ mod tests { fn_service(move |req| { let (client, server) = Io::create(); store2.borrow_mut().push((req, server)); - Box::pin(async move { Ok(nio::Io::new(client).into_boxed()) }) + Box::pin(async move { Ok(nio::Io::new(client).seal()) }) }), Duration::from_secs(10), Duration::from_secs(10), diff --git a/ntex/src/http/client/response.rs b/ntex/src/http/client/response.rs index 0f8618ba..5cff32a7 100644 --- a/ntex/src/http/client/response.rs +++ b/ntex/src/http/client/response.rs @@ -63,6 +63,10 @@ impl ClientResponse { ClientResponse { head, payload } } + pub(crate) fn with_empty_payload(head: ResponseHead) -> Self { + ClientResponse::new(head, Payload::None) + } + #[inline] pub(crate) fn head(&self) -> &ResponseHead { &self.head diff --git a/ntex/src/http/client/ws.rs b/ntex/src/http/client/ws.rs deleted file mode 100644 index c603072b..00000000 --- a/ntex/src/http/client/ws.rs +++ /dev/null @@ -1,597 +0,0 @@ -//! Websockets client -use std::{convert::TryFrom, fmt, net::SocketAddr, rc::Rc, str}; - -#[cfg(feature = "cookie")] -use coo_kie::{Cookie, CookieJar}; -use nanorand::{Rng, WyRand}; - -use crate::http::error::HttpError; -use crate::http::header::{self, HeaderName, HeaderValue, AUTHORIZATION}; -use crate::http::{ConnectionType, Payload, RequestHead, StatusCode, Uri}; -use crate::io::{DispatchItem, Dispatcher, IoBoxed}; -use crate::service::{apply_fn, into_service, IntoService, Service}; -use crate::util::{sink, Either, Ready}; -use crate::{channel::mpsc, rt, time::timeout, ws}; - -pub use crate::ws::{CloseCode, CloseReason, Frame, Message}; - -use super::error::{InvalidUrl, SendRequestError, WsClientError}; -use super::response::ClientResponse; -use super::ClientConfig; - -/// `WebSocket` connection -pub struct WsRequest { - pub(crate) head: RequestHead, - err: Option, - origin: Option, - protocols: Option, - addr: Option, - max_size: usize, - server_mode: bool, - #[cfg(feature = "cookie")] - cookies: Option, - config: Rc, -} - -impl WsRequest { - /// Create new websocket connection - pub(super) fn new(uri: U, config: Rc) -> Self - where - Uri: TryFrom, - >::Error: Into, - { - let (head, err) = match Uri::try_from(uri) { - Ok(uri) => ( - RequestHead { - uri, - ..Default::default() - }, - None, - ), - Err(e) => (Default::default(), Some(e.into())), - }; - - WsRequest { - head, - err, - config, - addr: None, - origin: None, - protocols: None, - max_size: 65_536, - server_mode: false, - #[cfg(feature = "cookie")] - cookies: None, - } - } - - /// Set socket address of the server. - /// - /// This address is used for connection. If address is not - /// provided url's host name get resolved. - pub fn address(mut self, addr: SocketAddr) -> Self { - self.addr = Some(addr); - self - } - - /// Set supported websocket protocols - pub fn protocols(mut self, protos: U) -> Self - where - U: IntoIterator, - V: AsRef, - { - let mut protos = protos - .into_iter() - .fold(String::new(), |acc, s| acc + s.as_ref() + ","); - protos.pop(); - self.protocols = Some(protos); - self - } - - #[cfg(feature = "cookie")] - /// Set a cookie - pub fn cookie(mut self, cookie: Cookie<'_>) -> Self { - if self.cookies.is_none() { - let mut jar = CookieJar::new(); - jar.add(cookie.into_owned()); - self.cookies = Some(jar) - } else { - self.cookies.as_mut().unwrap().add(cookie.into_owned()); - } - self - } - - /// Set request Origin - pub fn origin(mut self, origin: V) -> Self - where - HeaderValue: TryFrom, - HttpError: From, - { - match HeaderValue::try_from(origin) { - Ok(value) => self.origin = Some(value), - Err(e) => self.err = Some(e.into()), - } - self - } - - /// Set max frame size - /// - /// By default max size is set to 64kb - pub fn max_frame_size(mut self, size: usize) -> Self { - self.max_size = size; - self - } - - /// Disable payload masking. By default ws client masks frame payload. - pub fn server_mode(mut self) -> Self { - self.server_mode = true; - self - } - - /// Append a header. - /// - /// Header gets appended to existing header. - /// To override header use `set_header()` method. - pub fn header(mut self, key: K, value: V) -> Self - where - HeaderName: TryFrom, - HeaderValue: TryFrom, - >::Error: Into, - >::Error: Into, - { - match HeaderName::try_from(key) { - Ok(key) => match HeaderValue::try_from(value) { - Ok(value) => { - self.head.headers.append(key, value); - } - Err(e) => self.err = Some(e.into()), - }, - Err(e) => self.err = Some(e.into()), - } - self - } - - /// Insert a header, replaces existing header. - pub fn set_header(mut self, key: K, value: V) -> Self - where - HeaderName: TryFrom, - HeaderValue: TryFrom, - >::Error: Into, - >::Error: Into, - { - match HeaderName::try_from(key) { - Ok(key) => match HeaderValue::try_from(value) { - Ok(value) => { - self.head.headers.insert(key, value); - } - Err(e) => self.err = Some(e.into()), - }, - Err(e) => self.err = Some(e.into()), - } - self - } - - /// Insert a header only if it is not yet set. - pub fn set_header_if_none(mut self, key: K, value: V) -> Self - where - HeaderName: TryFrom, - HeaderValue: TryFrom, - >::Error: Into, - >::Error: Into, - { - match HeaderName::try_from(key) { - Ok(key) => { - if !self.head.headers.contains_key(&key) { - match HeaderValue::try_from(value) { - Ok(value) => { - self.head.headers.insert(key, value); - } - Err(e) => self.err = Some(e.into()), - } - } - } - Err(e) => self.err = Some(e.into()), - } - self - } - - /// Set HTTP basic authorization header - pub fn basic_auth(self, username: U, password: Option<&str>) -> Self - where - U: fmt::Display, - { - let auth = match password { - Some(password) => format!("{}:{}", username, password), - None => format!("{}:", username), - }; - self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth))) - } - - /// Set HTTP bearer authentication header - pub fn bearer_auth(self, token: T) -> Self - where - T: fmt::Display, - { - self.header(AUTHORIZATION, format!("Bearer {}", token)) - } - - /// Complete request construction and connect to a websockets server. - pub async fn connect(mut self) -> Result { - if let Some(e) = self.err.take() { - return Err(WsClientError::from(e)); - } - - // validate uri - let uri = &self.head.uri; - if uri.host().is_none() { - return Err(InvalidUrl::MissingHost.into()); - } else if uri.scheme().is_none() { - return Err(InvalidUrl::MissingScheme.into()); - } else if let Some(scheme) = uri.scheme() { - match scheme.as_str() { - "http" | "ws" | "https" | "wss" => (), - _ => return Err(InvalidUrl::UnknownScheme.into()), - } - } else { - return Err(InvalidUrl::UnknownScheme.into()); - } - - if !self.head.headers.contains_key(header::HOST) { - self.head.headers.insert( - header::HOST, - HeaderValue::from_str(uri.host().unwrap()).unwrap(), - ); - } - - #[cfg(feature = "cookie")] - { - use percent_encoding::percent_encode; - use std::fmt::Write as FmtWrite; - - // set cookies - if let Some(ref mut jar) = self.cookies { - let mut cookie = String::new(); - for c in jar.delta() { - let name = percent_encode( - c.name().as_bytes(), - crate::http::helpers::USERINFO, - ); - let value = percent_encode( - c.value().as_bytes(), - crate::http::helpers::USERINFO, - ); - let _ = write!(&mut cookie, "; {}={}", name, value); - } - self.head.headers.insert( - header::COOKIE, - HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), - ); - } - } - - // origin - if let Some(origin) = self.origin.take() { - self.head.headers.insert(header::ORIGIN, origin); - } - - self.head.set_connection_type(ConnectionType::Upgrade); - self.head - .headers - .insert(header::UPGRADE, HeaderValue::from_static("websocket")); - self.head.headers.insert( - header::SEC_WEBSOCKET_VERSION, - HeaderValue::from_static("13"), - ); - - if let Some(protocols) = self.protocols.take() { - self.head.headers.insert( - header::SEC_WEBSOCKET_PROTOCOL, - HeaderValue::try_from(protocols.as_str()).unwrap(), - ); - } - - // Generate a random key for the `Sec-WebSocket-Key` header. - // a base64-encoded (see Section 4 of [RFC4648]) value that, - // when decoded, is 16 bytes in length (RFC 6455) - let mut sec_key: [u8; 16] = [0; 16]; - WyRand::new().fill(&mut sec_key); - let key = base64::encode(&sec_key); - - self.head.headers.insert( - header::SEC_WEBSOCKET_KEY, - HeaderValue::try_from(key.as_str()).unwrap(), - ); - - let head = self.head; - let max_size = self.max_size; - let server_mode = self.server_mode; - - let fut = self.config.connector.open_tunnel(head.into(), self.addr); - - // set request timeout - let (head, io, _) = if self.config.timeout.non_zero() { - timeout(self.config.timeout, fut) - .await - .map_err(|_| SendRequestError::Timeout) - .and_then(|res| res)? - } else { - fut.await? - }; - - // verify response - if head.status != StatusCode::SWITCHING_PROTOCOLS { - return Err(WsClientError::InvalidResponseStatus(head.status)); - } - - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_ascii_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - log::trace!("Invalid upgrade header"); - return Err(WsClientError::InvalidUpgradeHeader); - } - - // Check for "CONNECTION" header - if let Some(conn) = head.headers.get(&header::CONNECTION) { - if let Ok(s) = conn.to_str() { - if !s.to_ascii_lowercase().contains("upgrade") { - log::trace!("Invalid connection header: {}", s); - return Err(WsClientError::InvalidConnectionHeader(conn.clone())); - } - } else { - log::trace!("Invalid connection header: {:?}", conn); - return Err(WsClientError::InvalidConnectionHeader(conn.clone())); - } - } else { - log::trace!("Missing connection header"); - return Err(WsClientError::MissingConnectionHeader); - } - - if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { - let encoded = ws::hash_key(key.as_ref()); - if hdr_key.as_bytes() != encoded.as_bytes() { - log::trace!( - "Invalid challenge response: expected: {} received: {:?}", - encoded, - key - ); - return Err(WsClientError::InvalidChallengeResponse( - encoded, - hdr_key.clone(), - )); - } - } else { - log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); - return Err(WsClientError::MissingWebSocketAcceptHeader); - }; - - // response and ws io - Ok(WsConnection::new( - ClientResponse::new(head, Payload::None), - io, - if server_mode { - ws::Codec::new().max_size(max_size) - } else { - ws::Codec::new().max_size(max_size).client_mode() - }, - )) - } -} - -impl fmt::Debug for WsRequest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "\nWebsocketsRequest {}:{}", - self.head.method, self.head.uri - )?; - writeln!(f, " headers:")?; - for (key, val) in self.head.headers.iter() { - writeln!(f, " {:?}: {:?}", key, val)?; - } - Ok(()) - } -} - -pub struct WsConnection { - io: IoBoxed, - codec: ws::Codec, - res: ClientResponse, -} - -impl WsConnection { - fn new(res: ClientResponse, io: IoBoxed, codec: ws::Codec) -> Self { - Self { io, codec, res } - } - - /// Get ws sink - pub fn sink(&self) -> ws::WsSink { - ws::WsSink::new(self.io.get_ref(), self.codec.clone()) - } - - /// Get reference to response - pub fn response(&self) -> &ClientResponse { - &self.res - } - - // TODO: fix close frame handling - /// Start client websockets with `SinkService` and `mpsc::Receiver` - pub fn start_default(self) -> mpsc::Receiver>> { - let (tx, rx): (_, mpsc::Receiver>>) = - mpsc::channel(); - - rt::spawn(async move { - let io = self.io.get_ref(); - let srv = sink::SinkService::new(tx.clone()).map(|_| None); - - if let Err(err) = self - .start(into_service(move |item| { - let io = io.clone(); - let close = matches!(item, ws::Frame::Close(_)); - let fut = srv.call(Ok::<_, ws::WsError<()>>(item)); - async move { - let result = fut.await.map_err(|_| ()); - if close { - io.close(); - } - result - } - })) - .await - { - let _ = tx.send(Err(err)); - } - }); - - rx - } - - /// Start client websockets service. - pub async fn start(self, service: U) -> Result<(), ws::WsError> - where - T: Service> + 'static, - U: IntoService, - { - let service = apply_fn( - service.into_service().map_err(ws::WsError::Service), - |req, srv| match req { - DispatchItem::Item(item) => Either::Left(srv.call(item)), - DispatchItem::WBackPressureEnabled - | DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)), - DispatchItem::KeepAliveTimeout => { - Either::Right(Ready::Err(ws::WsError::KeepAlive)) - } - DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => { - Either::Right(Ready::Err(ws::WsError::Protocol(e))) - } - DispatchItem::Disconnect(Some(e)) => { - Either::Right(Ready::Err(ws::WsError::Io(e))) - } - DispatchItem::Disconnect(None) => { - Either::Right(Ready::Err(ws::WsError::Disconnected)) - } - }, - ); - - Dispatcher::new(self.io, self.codec, service, Default::default()).await - } - - /// Consumes the `WsConnection`, returning it'as underlying I/O framed object - /// and response. - pub fn into_inner(self) -> (ClientResponse, IoBoxed, ws::Codec) { - (self.res, self.io, self.codec) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::http::client::Client; - - #[crate::rt_test] - async fn test_debug() { - let request = Client::new().ws("/").header("x-test", "111"); - let repr = format!("{:?}", request); - assert!(repr.contains("WebsocketsRequest")); - assert!(repr.contains("x-test")); - } - - #[crate::rt_test] - async fn test_header_override() { - let req = Client::build() - .header(header::CONTENT_TYPE, "111") - .finish() - .ws("/") - .set_header(header::CONTENT_TYPE, "222"); - - assert_eq!( - req.head - .headers - .get(header::CONTENT_TYPE) - .unwrap() - .to_str() - .unwrap(), - "222" - ); - } - - #[crate::rt_test] - async fn basic_auth() { - let req = Client::new() - .ws("/") - .basic_auth("username", Some("password")); - assert_eq!( - req.head - .headers - .get(header::AUTHORIZATION) - .unwrap() - .to_str() - .unwrap(), - "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" - ); - - let req = Client::new().ws("/").basic_auth("username", None); - assert_eq!( - req.head - .headers - .get(header::AUTHORIZATION) - .unwrap() - .to_str() - .unwrap(), - "Basic dXNlcm5hbWU6" - ); - } - - #[crate::rt_test] - async fn bearer_auth() { - let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n"); - assert_eq!( - req.head - .headers - .get(header::AUTHORIZATION) - .unwrap() - .to_str() - .unwrap(), - "Bearer someS3cr3tAutht0k3n" - ); - let _ = req.connect(); - } - - #[cfg(feature = "cookie")] - #[crate::rt_test] - async fn basics() { - let req = Client::new() - .ws("http://localhost/") - .origin("test-origin") - .max_frame_size(100) - .server_mode() - .protocols(&["v1", "v2"]) - .set_header_if_none(header::CONTENT_TYPE, "json") - .set_header_if_none(header::CONTENT_TYPE, "text") - .cookie(Cookie::build("cookie1", "value1").finish()); - assert_eq!( - req.origin.as_ref().unwrap().to_str().unwrap(), - "test-origin" - ); - assert_eq!(req.max_size, 100); - assert_eq!(req.server_mode, true); - assert_eq!(req.protocols, Some("v1,v2".to_string())); - assert_eq!( - req.head.headers.get(header::CONTENT_TYPE).unwrap(), - header::HeaderValue::from_static("json") - ); - - let _ = req.connect().await; - - assert!(Client::new().ws("/").connect().await.is_err()); - assert!(Client::new().ws("http:///test").connect().await.is_err()); - assert!(Client::new().ws("hmm://test.com/").connect().await.is_err()); - } -} diff --git a/ntex/src/http/error.rs b/ntex/src/http/error.rs index 10f56c41..fe6017fe 100644 --- a/ntex/src/http/error.rs +++ b/ntex/src/http/error.rs @@ -86,10 +86,6 @@ pub enum ParseError { Timeout, /// An `InvalidInput` occurred while trying to parse incoming stream. InvalidInput(&'static str), - /// An `io::Error` that occurred while trying to read or write to a network - /// stream. - #[display(fmt = "IO error: {}", _0)] - Io(io::Error), /// Parsing a field as string failed #[display(fmt = "UTF8 error: {}", _0)] Utf8(Utf8Error), @@ -321,21 +317,8 @@ mod tests { }; } - macro_rules! from_and_cause { - ($from:expr => $error:pat) => { - match ParseError::from($from) { - e @ $error => { - let desc = format!("{}", e); - assert_eq!(desc, format!("IO error: {}", $from)); - } - _ => unreachable!("{:?}", $from), - } - }; - } - #[test] fn test_from() { - from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => ParseError::Io(..)); from!(httparse::Error::HeaderName => ParseError::Header); from!(httparse::Error::HeaderName => ParseError::Header); from!(httparse::Error::HeaderValue => ParseError::Header); diff --git a/ntex/src/http/h1/decoder.rs b/ntex/src/http/h1/decoder.rs index bf21381a..ec135c95 100644 --- a/ntex/src/http/h1/decoder.rs +++ b/ntex/src/http/h1/decoder.rs @@ -674,7 +674,6 @@ impl ChunkedState { #[cfg(test)] mod tests { use super::*; - use crate::http::error::ParseError; use crate::http::header::{HeaderName, SET_COOKIE}; use crate::http::{HttpMessage, Method, Version}; use crate::util::{Bytes, BytesMut}; @@ -723,11 +722,7 @@ mod tests { macro_rules! expect_parse_err { ($e:expr) => {{ match MessageDecoder::::default().decode($e) { - Err(err) => { - if let ParseError::Io(_) = err { - unreachable!("Parse error expected") - } - } + Err(_) => (), _ => unreachable!("Error expected"), } }}; diff --git a/ntex/src/http/message.rs b/ntex/src/http/message.rs index 4ae8c92b..c5b956d0 100644 --- a/ntex/src/http/message.rs +++ b/ntex/src/http/message.rs @@ -45,7 +45,7 @@ pub struct RequestHead { pub headers: HeaderMap, pub extensions: RefCell, pub io: Option, - pub(super) flags: Flags, + pub(crate) flags: Flags, } impl Default for RequestHead { diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index 00af307d..c5c8f156 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -4,13 +4,11 @@ use std::{convert::TryFrom, io, net, str::FromStr, sync::mpsc, thread}; #[cfg(feature = "cookie")] use coo_kie::{Cookie, CookieJar}; -use crate::{io::Io, rt::System, server::Server, service::ServiceFactory}; +use crate::ws::{error::WsClientError, WsClient, WsConnection}; +use crate::{io::Filter, io::Io, rt::System, server::Server, service::ServiceFactory}; use crate::{time::Millis, time::Seconds, util::Bytes}; -use super::client::error::WsClientError; -use super::client::{ - ws::WsConnection, Client, ClientRequest, ClientResponse, Connector, -}; +use super::client::{Client, ClientRequest, ClientResponse, Connector}; use super::error::{HttpError, PayloadError}; use super::header::{HeaderMap, HeaderName, HeaderValue}; use super::payload::Payload; @@ -322,14 +320,54 @@ impl TestServer { response.body().limit(10_485_760).await } - /// Connect to websocket server at a given path - pub async fn ws_at(&mut self, path: &str) -> Result { - self.client.ws(self.url(path)).connect().await + /// Connect to a websocket server + pub async fn ws(&mut self) -> Result, WsClientError> { + self.ws_at("/").await } + /// Connect to websocket server at a given path + pub async fn ws_at( + &mut self, + path: &str, + ) -> Result, WsClientError> { + WsClient::build(self.url(path)) + .address(self.addr) + .timeout(Seconds(30)) + .finish() + .unwrap() + .connect() + .await + } + + #[cfg(feature = "openssl")] /// Connect to a websocket server - pub async fn ws(&mut self) -> Result { - self.ws_at("/").await + pub async fn wss(&mut self) -> Result, WsClientError> { + self.wss_at("/").await + } + + #[cfg(feature = "openssl")] + /// Connect to secure websocket server at a given path + pub async fn wss_at( + &mut self, + path: &str, + ) -> Result, WsClientError> { + use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x08http/1.1") + .map_err(|e| log::error!("Cannot set alpn protocol: {:?}", e)); + + WsClient::build(self.url(path)) + .address(self.addr) + .timeout(Seconds(30)) + .openssl(builder.build()) + .take() + .finish() + .unwrap() + .connect() + .await } /// Stop http server diff --git a/ntex/src/http/ws.rs b/ntex/src/http/ws.rs index 9b964d97..9ae5fb4a 100644 --- a/ntex/src/http/ws.rs +++ b/ntex/src/http/ws.rs @@ -54,7 +54,7 @@ impl ResponseError for HandshakeError { } } -impl ResponseError for crate::ws::ProtocolError {} +impl ResponseError for crate::ws::error::ProtocolError {} /// Verify `WebSocket` handshake request and create handshake reponse. // /// `protocols` is a sequence of known protocols. On successful handshake, diff --git a/ntex/src/web/test.rs b/ntex/src/web/test.rs index 2c17dc1f..3998d72a 100644 --- a/ntex/src/web/test.rs +++ b/ntex/src/web/test.rs @@ -10,8 +10,7 @@ use serde::de::DeserializeOwned; use serde::Serialize; use crate::http::body::MessageBody; -use crate::http::client::error::WsClientError; -use crate::http::client::{ws, Client, ClientRequest, ClientResponse, Connector}; +use crate::http::client::{Client, ClientRequest, ClientResponse, Connector}; use crate::http::error::{HttpError, PayloadError, ResponseError}; use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE}; use crate::http::test::TestRequest as HttpTestRequest; @@ -22,7 +21,8 @@ use crate::service::{ }; use crate::time::{sleep, Millis, Seconds}; use crate::util::{next, Bytes, BytesMut, Extensions, Ready}; -use crate::{rt::System, server::Server, Stream}; +use crate::ws::{error::WsClientError, WsClient, WsConnection}; +use crate::{io::Sealed, rt::System, server::Server, Stream}; use crate::web::config::AppConfig; use crate::web::error::{DefaultError, ErrorRenderer}; @@ -919,12 +919,50 @@ impl TestServer { } /// Connect to websocket server at a given path - pub async fn ws_at(&self, path: &str) -> Result { - self.client.ws(self.url(path)).connect().await + pub async fn ws_at( + &self, + path: &str, + ) -> Result, WsClientError> { + if self.ssl { + #[cfg(feature = "openssl")] + { + use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; + + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_verify(SslVerifyMode::NONE); + let _ = builder + .set_alpn_protos(b"\x08http/1.1") + .map_err(|e| log::error!("Cannot set alpn protocol: {:?}", e)); + + WsClient::build(self.url(path)) + .address(self.addr) + .timeout(Seconds(30)) + .openssl(builder.build()) + .take() + .finish() + .unwrap() + .connect() + .await + .map(|ws| ws.seal()) + } + #[cfg(not(feature = "openssl"))] + { + panic!("openssl feature is required") + } + } else { + WsClient::build(self.url(path)) + .address(self.addr) + .timeout(Seconds(30)) + .finish() + .unwrap() + .connect() + .await + .map(|ws| ws.seal()) + } } /// Connect to a websocket server - pub async fn ws(&self) -> Result { + pub async fn ws(&self) -> Result, WsClientError> { self.ws_at("/").await } diff --git a/ntex/src/ws/client.rs b/ntex/src/ws/client.rs new file mode 100644 index 00000000..2cbf2e2c --- /dev/null +++ b/ntex/src/ws/client.rs @@ -0,0 +1,936 @@ +//! Websockets client +use std::future::Future; +use std::{cell::RefCell, convert::TryFrom, fmt, marker, net, rc::Rc, str}; + +#[cfg(feature = "openssl")] +use crate::connect::openssl; +#[cfg(feature = "rustls")] +use crate::connect::rustls; +#[cfg(feature = "cookie")] +use coo_kie::{Cookie, CookieJar}; + +use nanorand::{Rng, WyRand}; + +use crate::connect::{Connect, ConnectError, Connector}; +use crate::http::header::{self, HeaderMap, HeaderName, HeaderValue, AUTHORIZATION}; +use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1}; +use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri}; +use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Sealed}; +use crate::service::{apply_fn, into_service, IntoService, Service}; +use crate::util::{sink, Either, Ready}; +use crate::{channel::mpsc, rt, time::timeout, time::Millis, ws}; + +use super::error::{WsClientBuilderError, WsClientError, WsError}; + +/// `WebSocket` client builder +pub struct WsClient { + connector: T, + head: Rc, + addr: Option, + max_size: usize, + server_mode: bool, + timeout: Millis, + extra_headers: RefCell>, + _t: marker::PhantomData, +} + +/// `WebSocket` client builder +pub struct WsClientBuilder { + inner: Option>, + err: Option, + protocols: Option, + origin: Option, + #[cfg(feature = "cookie")] + cookies: Option, +} + +struct Inner { + connector: T, + pub(crate) head: RequestHead, + addr: Option, + max_size: usize, + server_mode: bool, + timeout: Millis, + _t: marker::PhantomData, +} + +impl WsClient { + /// Create new websocket client builder + pub fn build( + uri: U, + ) -> WsClientBuilder< + Base, + impl Service, Response = Io, Error = ConnectError>, + > + where + Uri: TryFrom, + >::Error: Into, + { + WsClientBuilder::new(uri) + } + + /// Create new websocket client builder + pub fn with_connector(uri: U, connector: T) -> WsClientBuilder + where + Uri: TryFrom, + >::Error: Into, + F: Filter, + T: Service, Response = Io, Error = ConnectError>, + { + WsClientBuilder::new(uri).connector(connector) + } +} + +impl WsClient { + /// Insert a header, replaces existing header. + pub fn set_header(&self, key: K, value: V) -> Result<(), HttpError> + where + HeaderName: TryFrom, + HeaderValue: TryFrom, + >::Error: Into, + >::Error: Into, + { + let key = HeaderName::try_from(key).map_err(Into::into)?; + let value = HeaderValue::try_from(value).map_err(Into::into)?; + if let Some(headers) = &mut *self.extra_headers.borrow_mut() { + headers.insert(key, value); + return Ok(()); + } + let mut headers = HeaderMap::new(); + headers.insert(key, value); + *self.extra_headers.borrow_mut() = Some(headers); + Ok(()) + } + + /// Set HTTP basic authorization header + pub fn set_basic_auth( + &self, + username: U, + password: Option<&str>, + ) -> Result<(), HttpError> + where + U: fmt::Display, + { + let auth = match password { + Some(password) => format!("{}:{}", username, password), + None => format!("{}:", username), + }; + self.set_header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth))) + } + + /// Set HTTP bearer authentication header + pub fn set_bearer_auth(&self, token: U) -> Result<(), HttpError> + where + U: fmt::Display, + { + self.set_header(AUTHORIZATION, format!("Bearer {}", token)) + } +} + +impl WsClient +where + F: Filter, + T: Service, Response = Io, Error = ConnectError>, +{ + /// Complete request construction and connect to a websockets server. + pub fn connect( + &self, + ) -> impl Future, WsClientError>> { + let head = self.head.clone(); + let max_size = self.max_size; + let server_mode = self.server_mode; + let to = self.timeout; + let mut headers = self + .extra_headers + .borrow_mut() + .take() + .unwrap_or_else(HeaderMap::new); + + // Generate a random key for the `Sec-WebSocket-Key` header. + // a base64-encoded (see Section 4 of [RFC4648]) value that, + // when decoded, is 16 bytes in length (RFC 6455) + let mut sec_key: [u8; 16] = [0; 16]; + WyRand::new().fill(&mut sec_key); + let key = base64::encode(&sec_key); + + headers.insert( + header::SEC_WEBSOCKET_KEY, + HeaderValue::try_from(key.as_str()).unwrap(), + ); + + let msg = Connect::new(head.uri.clone()).set_addr(self.addr); + let fut = self.connector.call(msg); + + async move { + let io = fut.await?; + + // create Framed and send request + let codec = h1::ClientCodec::default(); + + // send request and read response + let fut = async { + io.send( + &codec, + (RequestHeadType::Rc(head, Some(headers)), BodySize::None).into(), + ) + .await?; + io.recv(&codec) + .await? + .ok_or(WsClientError::Disconnected(None)) + }; + + // set request timeout + let response = if to.non_zero() { + timeout(to, fut) + .await + .map_err(|_| WsClientError::Timeout) + .and_then(|res| res)? + } else { + fut.await? + }; + + // verify response + if response.status != StatusCode::SWITCHING_PROTOCOLS { + return Err(WsClientError::InvalidResponseStatus(response.status)); + } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = response.headers.get(&header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + log::trace!("Invalid upgrade header"); + return Err(WsClientError::InvalidUpgradeHeader); + } + + // Check for "CONNECTION" header + if let Some(conn) = response.headers.get(&header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_ascii_lowercase().contains("upgrade") { + log::trace!("Invalid connection header: {}", s); + return Err(WsClientError::InvalidConnectionHeader( + conn.clone(), + )); + } + } else { + log::trace!("Invalid connection header: {:?}", conn); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Missing connection header"); + return Err(WsClientError::MissingConnectionHeader); + } + + if let Some(hdr_key) = response.headers.get(&header::SEC_WEBSOCKET_ACCEPT) { + let encoded = ws::hash_key(key.as_ref()); + if hdr_key.as_bytes() != encoded.as_bytes() { + log::trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, + key + ); + return Err(WsClientError::InvalidChallengeResponse( + encoded, + hdr_key.clone(), + )); + } + } else { + log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(WsClientError::MissingWebSocketAcceptHeader); + }; + + // response and ws io + Ok(WsConnection::new( + io, + ClientResponse::with_empty_payload(response), + if server_mode { + ws::Codec::new().max_size(max_size) + } else { + ws::Codec::new().max_size(max_size).client_mode() + }, + )) + } + } +} + +impl fmt::Debug for WsClient { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "\nWsClient {}:{}", self.head.method, self.head.uri)?; + writeln!(f, " headers:")?; + for (key, val) in self.head.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + Ok(()) + } +} + +impl WsClientBuilder { + /// Create new websocket connector + fn new( + uri: U, + ) -> WsClientBuilder< + Base, + impl Service, Response = Io, Error = ConnectError>, + > + where + Uri: TryFrom, + >::Error: Into, + { + let (head, err) = match Uri::try_from(uri) { + Ok(uri) => ( + RequestHead { + uri, + ..Default::default() + }, + None, + ), + Err(e) => (Default::default(), Some(e.into())), + }; + + WsClientBuilder { + err, + origin: None, + protocols: None, + inner: Some(Inner { + head, + connector: Connector::::default(), + addr: None, + max_size: 65_536, + server_mode: false, + timeout: Millis(5_000), + _t: marker::PhantomData, + }), + #[cfg(feature = "cookie")] + cookies: None, + } + } +} + +impl WsClientBuilder +where + T: Service, Response = Io, Error = ConnectError>, +{ + /// Set socket address of the server. + /// + /// This address is used for connection. If address is not + /// provided url's host name get resolved. + pub fn address(&mut self, addr: net::SocketAddr) -> &mut Self { + if let Some(parts) = parts(&mut self.inner, &self.err) { + parts.addr = Some(addr); + } + self + } + + /// Set supported websocket protocols + pub fn protocols(&mut self, protos: U) -> &mut Self + where + U: IntoIterator, + V: AsRef, + { + let mut protos = protos + .into_iter() + .fold(String::new(), |acc, s| acc + s.as_ref() + ","); + protos.pop(); + self.protocols = Some(protos); + self + } + + #[cfg(feature = "cookie")] + /// Set a cookie + pub fn cookie(&mut self, cookie: Cookie<'_>) -> &mut Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Set request Origin + pub fn origin(&mut self, origin: V) -> &mut Self + where + HeaderValue: TryFrom, + HttpError: From, + { + match HeaderValue::try_from(origin) { + Ok(value) => self.origin = Some(value), + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_frame_size(&mut self, size: usize) -> &mut Self { + if let Some(parts) = parts(&mut self.inner, &self.err) { + parts.max_size = size; + } + self + } + + /// Disable payload masking. By default ws client masks frame payload. + pub fn server_mode(&mut self) -> &mut Self { + if let Some(parts) = parts(&mut self.inner, &self.err) { + parts.server_mode = true; + } + self + } + + /// Append a header. + /// + /// Header gets appended to existing header. + /// To override header use `set_header()` method. + pub fn header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: TryFrom, + HeaderValue: TryFrom, + >::Error: Into, + >::Error: Into, + { + if let Some(parts) = parts(&mut self.inner, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match HeaderValue::try_from(value) { + Ok(value) => { + parts.head.headers.append(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + } + self + } + + /// Insert a header, replaces existing header. + pub fn set_header(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: TryFrom, + HeaderValue: TryFrom, + >::Error: Into, + >::Error: Into, + { + if let Some(parts) = parts(&mut self.inner, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => match HeaderValue::try_from(value) { + Ok(value) => { + parts.head.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + } + self + } + + /// Insert a header only if it is not yet set. + pub fn set_header_if_none(&mut self, key: K, value: V) -> &mut Self + where + HeaderName: TryFrom, + HeaderValue: TryFrom, + >::Error: Into, + >::Error: Into, + { + if let Some(parts) = parts(&mut self.inner, &self.err) { + match HeaderName::try_from(key) { + Ok(key) => { + if !parts.head.headers.contains_key(&key) { + match HeaderValue::try_from(value) { + Ok(value) => { + parts.head.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + } + } + } + Err(e) => self.err = Some(e.into()), + } + } + self + } + + /// Set HTTP basic authorization header + pub fn basic_auth(&mut self, username: U, password: Option<&str>) -> &mut Self + where + U: fmt::Display, + { + let auth = match password { + Some(password) => format!("{}:{}", username, password), + None => format!("{}:", username), + }; + self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth))) + } + + /// Set HTTP bearer authentication header + pub fn bearer_auth(&mut self, token: U) -> &mut Self + where + U: fmt::Display, + { + self.header(AUTHORIZATION, format!("Bearer {}", token)) + } + + /// Set request timeout. + /// + /// Request timeout is the total time before a response must be received. + /// Default value is 5 seconds. + pub fn timeout>(&mut self, timeout: U) -> &mut Self { + if let Some(parts) = parts(&mut self.inner, &self.err) { + parts.timeout = timeout.into(); + } + self + } + + /// Use custom connector + pub fn connector(&mut self, connector: T1) -> WsClientBuilder + where + F1: Filter, + T1: Service, Response = Io, Error = ConnectError>, + { + let inner = self.inner.take().expect("cannot reuse WsClient builder"); + + WsClientBuilder { + inner: Some(Inner { + connector, + head: inner.head, + addr: inner.addr, + max_size: inner.max_size, + server_mode: inner.server_mode, + timeout: inner.timeout, + _t: marker::PhantomData, + }), + err: self.err.take(), + protocols: self.protocols.take(), + origin: self.origin.take(), + #[cfg(feature = "cookie")] + cookies: self.cookies.take(), + } + } + + #[cfg(feature = "openssl")] + /// Use openssl connector. + pub fn openssl( + &mut self, + connector: openssl::SslConnector, + ) -> WsClientBuilder> { + self.connector(openssl::Connector::new(connector)) + } + + #[cfg(feature = "rustls")] + /// Use rustls connector. + pub fn rustls( + &mut self, + config: std::sync::Arc, + ) -> WsClientBuilder> { + self.connector(rustls::Connector::from(config)) + } + + /// This method construct new `WsClientBuilder` + pub fn take(&mut self) -> WsClientBuilder { + WsClientBuilder { + inner: self.inner.take(), + err: self.err.take(), + origin: self.origin.take(), + protocols: self.protocols.take(), + #[cfg(feature = "cookie")] + cookies: self.cookies.take(), + } + } + + /// Complete building process and construct websockets client. + pub fn finish(&mut self) -> Result, WsClientBuilderError> { + if let Some(e) = self.err.take() { + return Err(WsClientBuilderError::Http(e)); + } + + // #[allow(unused_mut)] + let mut inner = self.inner.take().expect("cannot reuse WsClient builder"); + + // validate uri + let uri = &inner.head.uri; + if uri.host().is_none() { + return Err(WsClientBuilderError::MissingHost); + } else if uri.scheme().is_none() { + return Err(WsClientBuilderError::MissingScheme); + } else if let Some(scheme) = uri.scheme() { + match scheme.as_str() { + "http" | "ws" | "https" | "wss" => (), + _ => return Err(WsClientBuilderError::UnknownScheme), + } + } else { + return Err(WsClientBuilderError::UnknownScheme); + } + + if !inner.head.headers.contains_key(header::HOST) { + inner.head.headers.insert( + header::HOST, + HeaderValue::from_str(uri.host().unwrap()).unwrap(), + ); + } + + #[cfg(feature = "cookie")] + { + use percent_encoding::percent_encode; + use std::fmt::Write as FmtWrite; + + // set cookies + if let Some(ref mut jar) = self.cookies { + let mut cookie = String::new(); + for c in jar.delta() { + let name = percent_encode( + c.name().as_bytes(), + crate::http::helpers::USERINFO, + ); + let value = percent_encode( + c.value().as_bytes(), + crate::http::helpers::USERINFO, + ); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + inner.head.headers.insert( + header::COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + } + + // origin + if let Some(origin) = self.origin.take() { + inner.head.headers.insert(header::ORIGIN, origin); + } + + inner.head.set_connection_type(ConnectionType::Upgrade); + inner + .head + .headers + .insert(header::UPGRADE, HeaderValue::from_static("websocket")); + inner.head.headers.insert( + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + ); + + if let Some(protocols) = self.protocols.take() { + inner.head.headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::try_from(protocols.as_str()).unwrap(), + ); + } + + Ok(WsClient { + connector: inner.connector, + head: Rc::new(inner.head), + addr: inner.addr, + max_size: inner.max_size, + server_mode: inner.server_mode, + timeout: inner.timeout, + extra_headers: RefCell::new(None), + _t: marker::PhantomData, + }) + } +} + +#[inline] +fn parts<'a, F, T>( + parts: &'a mut Option>, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + parts.as_mut() +} + +impl fmt::Debug for WsClientBuilder { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(ref parts) = self.inner { + writeln!( + f, + "\nWsClientBuilder {}:{}", + parts.head.method, parts.head.uri + )?; + writeln!(f, " headers:")?; + for (key, val) in parts.head.headers.iter() { + writeln!(f, " {:?}: {:?}", key, val)?; + } + } else { + writeln!(f, "WsClientBuilder(Consumed)")?; + } + Ok(()) + } +} + +pub struct WsConnection { + io: Io, + codec: ws::Codec, + res: ClientResponse, +} + +impl WsConnection { + fn new(io: Io, res: ClientResponse, codec: ws::Codec) -> Self { + Self { io, codec, res } + } + + /// Get reference to response + pub fn response(&self) -> &ClientResponse { + &self.res + } +} + +impl WsConnection { + /// Get ws sink + pub fn sink(&self) -> ws::WsSink { + ws::WsSink::new(self.io.get_ref(), self.codec.clone()) + } + + /// Consumes the `WsConnection`, returning it'as underlying I/O stream object + /// and response. + pub fn into_inner(self) -> (Io, ws::Codec, ClientResponse) { + (self.io, self.codec, self.res) + } +} + +impl WsConnection { + // TODO: fix close frame handling + /// Start client websockets with `SinkService` and `mpsc::Receiver` + pub fn start_default(self) -> mpsc::Receiver>> { + let (tx, rx): (_, mpsc::Receiver>>) = + mpsc::channel(); + + rt::spawn(async move { + let io = self.io.get_ref(); + let srv = sink::SinkService::new(tx.clone()).map(|_| None); + + if let Err(err) = self + .start(into_service(move |item| { + let io = io.clone(); + let close = matches!(item, ws::Frame::Close(_)); + let fut = srv.call(Ok::<_, WsError<()>>(item)); + async move { + let result = fut.await.map_err(|_| ()); + if close { + io.close(); + } + result + } + })) + .await + { + let _ = tx.send(Err(err)); + } + }); + + rx + } + + /// Start client websockets service. + pub async fn start(self, service: U) -> Result<(), WsError> + where + T: Service> + 'static, + U: IntoService, + { + let service = apply_fn( + service.into_service().map_err(WsError::Service), + |req, srv| match req { + DispatchItem::Item(item) => Either::Left(srv.call(item)), + DispatchItem::WBackPressureEnabled + | DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)), + DispatchItem::KeepAliveTimeout => { + Either::Right(Ready::Err(WsError::KeepAlive)) + } + DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => { + Either::Right(Ready::Err(WsError::Protocol(e))) + } + DispatchItem::Disconnect(e) => { + Either::Right(Ready::Err(WsError::Disconnected(e))) + } + }, + ); + + Dispatcher::new(self.io, self.codec, service, Default::default()).await + } +} + +impl WsConnection { + /// Convert I/O stream to boxed stream; + pub fn seal(self) -> WsConnection { + WsConnection { + io: self.io.seal(), + codec: self.codec, + res: self.res, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[crate::rt_test] + async fn test_debug() { + let mut builder = WsClient::build("http://localhost") + .header("x-test", "111") + .take(); + let repr = format!("{:?}", builder); + assert!(repr.contains("WsClientBuilder")); + assert!(repr.contains("x-test")); + + let client = builder.finish().unwrap(); + let repr = format!("{:?}", client); + assert!(repr.contains("WsClient")); + assert!(repr.contains("x-test")); + } + + #[crate::rt_test] + async fn header_override() { + let req = WsClient::build("http://localhost") + .header(header::CONTENT_TYPE, "111") + .set_header(header::CONTENT_TYPE, "222") + .finish() + .unwrap(); + + assert_eq!( + req.head + .headers + .get(header::CONTENT_TYPE) + .unwrap() + .to_str() + .unwrap(), + "222" + ); + } + + #[test] + fn basic_errs() { + let err = WsClient::build("localhost").finish().err().unwrap(); + assert!(matches!(err, WsClientBuilderError::MissingScheme)); + let err = WsClient::build("unknown://localhost") + .finish() + .err() + .unwrap(); + assert!(matches!(err, WsClientBuilderError::UnknownScheme)); + let err = WsClient::build("/").finish().err().unwrap(); + assert!(matches!(err, WsClientBuilderError::MissingHost)); + } + + #[crate::rt_test] + async fn basic_auth() { + let client = WsClient::build("http://localhost") + .basic_auth("username", Some("password")) + .finish() + .unwrap(); + assert_eq!( + client + .head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + ); + + let client = WsClient::build("http://localhost") + .basic_auth("username", None) + .finish() + .unwrap(); + assert_eq!( + client + .head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6" + ); + + client.set_basic_auth("username", Some("password")).unwrap(); + assert_eq!( + client + .extra_headers + .borrow() + .as_ref() + .unwrap() + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Basic dXNlcm5hbWU6cGFzc3dvcmQ=" + ); + } + + #[crate::rt_test] + async fn bearer_auth() { + let client = WsClient::build("http://localhost") + .bearer_auth("someS3cr3tAutht0k3n") + .finish() + .unwrap(); + assert_eq!( + client + .head + .headers + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Bearer someS3cr3tAutht0k3n" + ); + + let _ = client.set_bearer_auth("someS3cr3tAutht0k2n"); + assert_eq!( + client + .extra_headers + .borrow() + .as_ref() + .unwrap() + .get(header::AUTHORIZATION) + .unwrap() + .to_str() + .unwrap(), + "Bearer someS3cr3tAutht0k2n" + ); + + let _ = client.connect(); + } + + #[cfg(feature = "cookie")] + #[crate::rt_test] + async fn basics() { + let mut builder = WsClient::build("http://localhost/") + .origin("test-origin") + .max_frame_size(100) + .server_mode() + .protocols(&["v1", "v2"]) + .set_header_if_none(header::CONTENT_TYPE, "json") + .set_header_if_none(header::CONTENT_TYPE, "text") + .cookie(Cookie::build("cookie1", "value1").finish()) + .take(); + assert_eq!( + builder.origin.as_ref().unwrap().to_str().unwrap(), + "test-origin" + ); + assert_eq!(builder.inner.as_ref().unwrap().max_size, 100); + assert_eq!(builder.inner.as_ref().unwrap().server_mode, true); + assert_eq!(builder.protocols, Some("v1,v2".to_string())); + + let client = builder.finish().unwrap(); + assert_eq!( + client.head.headers.get(header::CONTENT_TYPE).unwrap(), + header::HeaderValue::from_static("json") + ); + + let _ = client.connect().await; + + assert!(WsClient::build("/").finish().is_err()); + assert!(WsClient::build("http:///test").finish().is_err()); + assert!(WsClient::build("hmm://test.com/").finish().is_err()); + } +} diff --git a/ntex/src/ws/codec.rs b/ntex/src/ws/codec.rs index b1b5e1e5..60d95002 100644 --- a/ntex/src/ws/codec.rs +++ b/ntex/src/ws/codec.rs @@ -3,9 +3,9 @@ use std::cell::Cell; use crate::codec::{Decoder, Encoder}; use crate::util::{ByteString, Bytes, BytesMut}; +use super::error::ProtocolError; use super::frame::Parser; use super::proto::{CloseReason, OpCode}; -use super::ProtocolError; /// WebSocket message #[derive(Debug, PartialEq)] diff --git a/ntex/src/ws/error.rs b/ntex/src/ws/error.rs new file mode 100644 index 00000000..8f103bb1 --- /dev/null +++ b/ntex/src/ws/error.rs @@ -0,0 +1,128 @@ +//! WebSocket protocol related errors. +use std::{error, io}; + +use derive_more::{Display, From}; + +use crate::connect::ConnectError; +use crate::http::error::{HttpError, ParseError}; +use crate::http::{header::HeaderValue, StatusCode}; +use crate::util::Either; + +use super::OpCode; + +/// Websocket service errors +#[derive(Debug, Display)] +pub enum WsError { + Service(E), + /// Keep-alive error + KeepAlive, + /// Ws protocol level error + Protocol(ProtocolError), + /// Peer has been disconnected + #[display(fmt = "Peer has been disconnected: {:?}", _0)] + Disconnected(Option), +} + +/// Websocket protocol errors +#[derive(Debug, Display, From)] +pub enum ProtocolError { + /// Received an unmasked frame from client + #[display(fmt = "Received an unmasked frame from client")] + UnmaskedFrame, + /// Received a masked frame from server + #[display(fmt = "Received a masked frame from server")] + MaskedFrame, + /// Encountered invalid opcode + #[display(fmt = "Invalid opcode: {}", _0)] + InvalidOpcode(u8), + /// Invalid control frame length + #[display(fmt = "Invalid control frame length: {}", _0)] + InvalidLength(usize), + /// Bad web socket op code + #[display(fmt = "Bad web socket op code")] + BadOpCode, + /// A payload reached size limit. + #[display(fmt = "A payload reached size limit.")] + Overflow, + /// Continuation is not started + #[display(fmt = "Continuation is not started.")] + ContinuationNotStarted, + /// Received new continuation but it is already started + #[display(fmt = "Received new continuation but it is already started")] + ContinuationStarted, + /// Unknown continuation fragment + #[display(fmt = "Unknown continuation fragment.")] + ContinuationFragment(OpCode), +} + +impl std::error::Error for ProtocolError {} + +/// Websocket client error +#[derive(Debug, Display, From)] +pub enum WsClientBuilderError { + #[display(fmt = "Missing url scheme")] + MissingScheme, + #[display(fmt = "Unknown url scheme")] + UnknownScheme, + #[display(fmt = "Missing host name")] + MissingHost, + #[display(fmt = "Url parse error: {}", _0)] + Http(HttpError), +} + +impl std::error::Error for WsClientBuilderError {} + +/// Websocket client error +#[derive(Debug, Display, From)] +pub enum WsClientError { + /// Invalid response + #[display(fmt = "Invalid response")] + InvalidResponse(ParseError), + /// Invalid response status + #[display(fmt = "Invalid response status")] + InvalidResponseStatus(StatusCode), + /// Invalid upgrade header + #[display(fmt = "Invalid upgrade header")] + InvalidUpgradeHeader, + /// Invalid connection header + #[display(fmt = "Invalid connection header")] + InvalidConnectionHeader(HeaderValue), + /// Missing CONNECTION header + #[display(fmt = "Missing CONNECTION header")] + MissingConnectionHeader, + /// Missing SEC-WEBSOCKET-ACCEPT header + #[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")] + MissingWebSocketAcceptHeader, + /// Invalid challenge response + #[display(fmt = "Invalid challenge response")] + InvalidChallengeResponse(String, HeaderValue), + /// Protocol error + #[display(fmt = "{}", _0)] + Protocol(ProtocolError), + /// Response took too long + #[display(fmt = "Timeout out while waiting for response")] + Timeout, + /// Failed to connect to host + #[display(fmt = "Failed to connect to host: {}", _0)] + Connect(ConnectError), + /// Connector has been disconnected + #[display(fmt = "Connector has been disconnected: {:?}", _0)] + Disconnected(Option), +} + +impl error::Error for WsClientError {} + +impl From> for WsClientError { + fn from(err: Either) -> Self { + match err { + Either::Left(err) => WsClientError::InvalidResponse(err), + Either::Right(err) => WsClientError::Disconnected(Some(err)), + } + } +} + +impl From> for WsClientError { + fn from(err: Either) -> Self { + WsClientError::Disconnected(Some(err.into_inner())) + } +} diff --git a/ntex/src/ws/frame.rs b/ntex/src/ws/frame.rs index 9f1028ac..e713b64b 100644 --- a/ntex/src/ws/frame.rs +++ b/ntex/src/ws/frame.rs @@ -4,7 +4,7 @@ use log::debug; use nanorand::{Rng, WyRand}; use super::proto::{CloseCode, CloseReason, OpCode}; -use super::{mask::apply_mask, ProtocolError}; +use super::{error::ProtocolError, mask::apply_mask}; use crate::util::{Buf, BufMut, BytesMut}; /// WebSocket frame parser. diff --git a/ntex/src/ws/mod.rs b/ntex/src/ws/mod.rs index 4ef8b14a..4a2947d4 100644 --- a/ntex/src/ws/mod.rs +++ b/ntex/src/ws/mod.rs @@ -3,10 +3,7 @@ //! To setup a `WebSocket`, first do web socket handshake then on success //! convert `Payload` into a `WsStream` stream and then use `WsWriter` to //! communicate with the peer. -use std::io; - -use derive_more::{Display, From}; - +mod client; mod codec; mod frame; mod mask; @@ -14,55 +11,11 @@ mod proto; mod sink; mod stream; +pub mod error; + +pub use self::client::{WsClient, WsClientBuilder, WsConnection}; pub use self::codec::{Codec, Frame, Item, Message}; pub use self::frame::Parser; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; pub use self::sink::WsSink; pub use self::stream::{StreamDecoder, StreamEncoder}; - -/// Websocket service errors -#[derive(Debug, Display)] -pub enum WsError { - Service(E), - KeepAlive, - Disconnected, - Protocol(ProtocolError), - Io(io::Error), -} - -/// Websocket protocol errors -#[derive(Debug, Display, From)] -pub enum ProtocolError { - /// Received an unmasked frame from client - #[display(fmt = "Received an unmasked frame from client")] - UnmaskedFrame, - /// Received a masked frame from server - #[display(fmt = "Received a masked frame from server")] - MaskedFrame, - /// Encountered invalid opcode - #[display(fmt = "Invalid opcode: {}", _0)] - InvalidOpcode(u8), - /// Invalid control frame length - #[display(fmt = "Invalid control frame length: {}", _0)] - InvalidLength(usize), - /// Bad web socket op code - #[display(fmt = "Bad web socket op code")] - BadOpCode, - /// A payload reached size limit. - #[display(fmt = "A payload reached size limit.")] - Overflow, - /// Continuation is not started - #[display(fmt = "Continuation is not started.")] - ContinuationNotStarted, - /// Received new continuation but it is already started - #[display(fmt = "Received new continuation but it is already started")] - ContinuationStarted, - /// Unknown continuation fragment - #[display(fmt = "Unknown continuation fragment.")] - ContinuationFragment(OpCode), - /// IO Error - #[display(fmt = "IO Error: {:?}", _0)] - Io(io::Error), -} - -impl std::error::Error for ProtocolError {} diff --git a/ntex/src/ws/sink.rs b/ntex/src/ws/sink.rs index 704ef3b8..b3b8a248 100644 --- a/ntex/src/ws/sink.rs +++ b/ntex/src/ws/sink.rs @@ -19,7 +19,7 @@ impl WsSink { pub fn send( &self, item: ws::Message, - ) -> impl Future> { + ) -> impl Future> { let inner = self.0.clone(); async move { diff --git a/ntex/src/ws/stream.rs b/ntex/src/ws/stream.rs index 3d9a482a..d11f42d2 100644 --- a/ntex/src/ws/stream.rs +++ b/ntex/src/ws/stream.rs @@ -2,7 +2,7 @@ use std::{ cell::RefCell, fmt, marker::PhantomData, pin::Pin, rc::Rc, task::Context, task::Poll, }; -use super::{Codec, Frame, Message, ProtocolError}; +use super::{error::ProtocolError, Codec, Frame, Message}; use crate::util::{Bytes, BytesMut}; use crate::{codec::Decoder, codec::Encoder, Sink, Stream}; diff --git a/ntex/tests/http_ws.rs b/ntex/tests/http_ws.rs index 2715213c..eee64999 100644 --- a/ntex/tests/http_ws.rs +++ b/ntex/tests/http_ws.rs @@ -49,7 +49,7 @@ impl Service for WsService { io.encode((res, body::BodySize::None).into(), &codec) .unwrap(); - Dispatcher::new(io.into_boxed(), ws::Codec::new(), service, Timer::default()) + Dispatcher::new(io.seal(), ws::Codec::new(), service, Timer::default()) .await .map_err(|_| panic!()) }; @@ -96,7 +96,7 @@ async fn test_simple() { let conn = srv.ws().await.unwrap(); assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS); - let (_, io, codec) = conn.into_inner(); + let (io, codec, _) = conn.into_inner(); io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) .await .unwrap(); diff --git a/ntex/tests/http_awc_ws.rs b/ntex/tests/http_ws_client.rs similarity index 96% rename from ntex/tests/http_awc_ws.rs rename to ntex/tests/http_ws_client.rs index c8f33830..b139b3dd 100644 --- a/ntex/tests/http_awc_ws.rs +++ b/ntex/tests/http_ws_client.rs @@ -41,7 +41,7 @@ async fn test_simple() { // start websocket service Dispatcher::new( - io.into_boxed(), + io.seal(), ws::Codec::default(), ws_service, Default::default(), @@ -53,7 +53,7 @@ async fn test_simple() { }); // client service - let (_, io, codec) = srv.ws().await.unwrap().into_inner(); + let (io, codec, _) = srv.ws().await.unwrap().into_inner(); io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) .await .unwrap(); diff --git a/ntex/tests/web_ws.rs b/ntex/tests/web_ws.rs index 32523b0b..2df011e2 100644 --- a/ntex/tests/web_ws.rs +++ b/ntex/tests/web_ws.rs @@ -37,7 +37,7 @@ async fn web_ws() { }); // client service - let (_, io, codec) = srv.ws().await.unwrap().into_inner(); + let (io, codec, _) = srv.ws().await.unwrap().into_inner(); io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) .await .unwrap();