mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-04 13:27:39 +03:00
commit
1e44bf0ecf
28 changed files with 1205 additions and 854 deletions
|
@ -1,5 +1,11 @@
|
||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
|
## [0.5.0-b.3] - 2021-12-xx
|
||||||
|
|
||||||
|
* Remove websocket support from http::client
|
||||||
|
|
||||||
|
* Add standalone ws::client
|
||||||
|
|
||||||
## [0.5.0-b.2] - 2021-12-22
|
## [0.5.0-b.2] - 2021-12-22
|
||||||
|
|
||||||
* Refactor write back-pressure for http1
|
* Refactor write back-pressure for http1
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "ntex"
|
name = "ntex"
|
||||||
version = "0.5.0-b.2"
|
version = "0.5.0-b.3"
|
||||||
authors = ["ntex contributors <team@ntex.rs>"]
|
authors = ["ntex contributors <team@ntex.rs>"]
|
||||||
description = "Framework for composable network services"
|
description = "Framework for composable network services"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
|
@ -17,6 +17,15 @@ pub struct Connector<T> {
|
||||||
inner: TlsConnector,
|
inner: TlsConnector,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> From<std::sync::Arc<ClientConfig>> for Connector<T> {
|
||||||
|
fn from(cfg: std::sync::Arc<ClientConfig>) -> Self {
|
||||||
|
Connector {
|
||||||
|
inner: TlsConnector::new(cfg),
|
||||||
|
connector: BaseConnector::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T> Connector<T> {
|
impl<T> Connector<T> {
|
||||||
pub fn new(config: ClientConfig) -> Self {
|
pub fn new(config: ClientConfig) -> Self {
|
||||||
Connector {
|
Connector {
|
||||||
|
|
|
@ -1,23 +1,13 @@
|
||||||
use std::{future::Future, net, pin::Pin};
|
use std::{future::Future, net, pin::Pin};
|
||||||
|
|
||||||
use crate::http::body::Body;
|
use crate::http::body::Body;
|
||||||
use crate::http::h1::ClientCodec;
|
use crate::http::RequestHeadType;
|
||||||
use crate::http::{RequestHeadType, ResponseHead};
|
use crate::service::Service;
|
||||||
use crate::io::IoBoxed;
|
|
||||||
use crate::Service;
|
|
||||||
|
|
||||||
use super::error::{ConnectError, SendRequestError};
|
use super::error::{ConnectError, SendRequestError};
|
||||||
use super::response::ClientResponse;
|
use super::response::ClientResponse;
|
||||||
use super::{Connect as ClientConnect, Connection};
|
use super::{Connect as ClientConnect, Connection};
|
||||||
|
|
||||||
pub(crate) type TunnelFuture = Pin<
|
|
||||||
Box<
|
|
||||||
dyn Future<
|
|
||||||
Output = Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError>,
|
|
||||||
>,
|
|
||||||
>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
pub(super) struct ConnectorWrapper<T>(pub(crate) T);
|
pub(super) struct ConnectorWrapper<T>(pub(crate) T);
|
||||||
|
|
||||||
pub(super) trait Connect {
|
pub(super) trait Connect {
|
||||||
|
@ -27,13 +17,6 @@ pub(super) trait Connect {
|
||||||
body: Body,
|
body: Body,
|
||||||
addr: Option<net::SocketAddr>,
|
addr: Option<net::SocketAddr>,
|
||||||
) -> Pin<Box<dyn Future<Output = Result<ClientResponse, SendRequestError>>>>;
|
) -> Pin<Box<dyn Future<Output = Result<ClientResponse, SendRequestError>>>>;
|
||||||
|
|
||||||
/// Send request, returns Response and Framed
|
|
||||||
fn open_tunnel(
|
|
||||||
&self,
|
|
||||||
head: RequestHeadType,
|
|
||||||
addr: Option<net::SocketAddr>,
|
|
||||||
) -> TunnelFuture;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Connect for ConnectorWrapper<T>
|
impl<T> Connect for ConnectorWrapper<T>
|
||||||
|
@ -63,23 +46,4 @@ where
|
||||||
.map(|(head, payload)| ClientResponse::new(head, payload))
|
.map(|(head, payload)| ClientResponse::new(head, payload))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn open_tunnel(
|
|
||||||
&self,
|
|
||||||
head: RequestHeadType,
|
|
||||||
addr: Option<net::SocketAddr>,
|
|
||||||
) -> TunnelFuture {
|
|
||||||
// connect to the host
|
|
||||||
let fut = self.0.call(ClientConnect {
|
|
||||||
uri: head.as_ref().uri.clone(),
|
|
||||||
addr,
|
|
||||||
});
|
|
||||||
|
|
||||||
Box::pin(async move {
|
|
||||||
let connection = fut.await?;
|
|
||||||
|
|
||||||
// send request
|
|
||||||
connection.open_tunnel(head).await
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ use h2::client::SendRequest;
|
||||||
use ntex_tls::types::HttpProtocol;
|
use ntex_tls::types::HttpProtocol;
|
||||||
|
|
||||||
use crate::http::body::MessageBody;
|
use crate::http::body::MessageBody;
|
||||||
use crate::http::h1::ClientCodec;
|
|
||||||
use crate::http::message::{RequestHeadType, ResponseHead};
|
use crate::http::message::{RequestHeadType, ResponseHead};
|
||||||
use crate::http::payload::Payload;
|
use crate::http::payload::Payload;
|
||||||
use crate::io::IoBoxed;
|
use crate::io::IoBoxed;
|
||||||
|
@ -91,24 +90,4 @@ impl Connection {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send request, returns Response and Framed
|
|
||||||
pub(super) async fn open_tunnel<H: Into<RequestHeadType>>(
|
|
||||||
mut self,
|
|
||||||
head: H,
|
|
||||||
) -> Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError> {
|
|
||||||
match self.io.take().unwrap() {
|
|
||||||
ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await,
|
|
||||||
ConnectionType::H2(io) => {
|
|
||||||
if let Some(mut pool) = self.pool.take() {
|
|
||||||
pool.release(Connection::new(
|
|
||||||
ConnectionType::H2(io),
|
|
||||||
self.created,
|
|
||||||
None,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Err(SendRequestError::TunnelNotSupported)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,7 +55,7 @@ impl Connector {
|
||||||
let conn = Connector {
|
let conn = Connector {
|
||||||
connector: boxed::service(
|
connector: boxed::service(
|
||||||
TcpConnector::new()
|
TcpConnector::new()
|
||||||
.map(|io| io.into_boxed())
|
.map(|io| io.seal())
|
||||||
.map_err(ConnectError::from),
|
.map_err(ConnectError::from),
|
||||||
),
|
),
|
||||||
ssl_connector: None,
|
ssl_connector: None,
|
||||||
|
@ -184,11 +184,8 @@ impl Connector {
|
||||||
> + 'static,
|
> + 'static,
|
||||||
F: Filter,
|
F: Filter,
|
||||||
{
|
{
|
||||||
self.connector = boxed::service(
|
self.connector =
|
||||||
connector
|
boxed::service(connector.map(|io| io.seal()).map_err(ConnectError::from));
|
||||||
.map(|io| io.into_boxed())
|
|
||||||
.map_err(ConnectError::from),
|
|
||||||
);
|
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,9 +200,7 @@ impl Connector {
|
||||||
F: Filter,
|
F: Filter,
|
||||||
{
|
{
|
||||||
self.ssl_connector = Some(boxed::service(
|
self.ssl_connector = Some(boxed::service(
|
||||||
connector
|
connector.map(|io| io.seal()).map_err(ConnectError::from),
|
||||||
.map(|io| io.into_boxed())
|
|
||||||
.map_err(ConnectError::from),
|
|
||||||
));
|
));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,53 +8,7 @@ use serde_json::error::Error as JsonError;
|
||||||
use crate::connect::openssl::{HandshakeError, SslError};
|
use crate::connect::openssl::{HandshakeError, SslError};
|
||||||
|
|
||||||
use crate::http::error::{HttpError, ParseError, PayloadError};
|
use crate::http::error::{HttpError, ParseError, PayloadError};
|
||||||
use crate::http::header::HeaderValue;
|
|
||||||
use crate::http::StatusCode;
|
|
||||||
use crate::util::Either;
|
use crate::util::Either;
|
||||||
use crate::ws::ProtocolError;
|
|
||||||
|
|
||||||
/// Websocket client error
|
|
||||||
#[derive(Debug, Display, From)]
|
|
||||||
pub enum WsClientError {
|
|
||||||
/// Invalid response status
|
|
||||||
#[display(fmt = "Invalid response status")]
|
|
||||||
InvalidResponseStatus(StatusCode),
|
|
||||||
/// Invalid upgrade header
|
|
||||||
#[display(fmt = "Invalid upgrade header")]
|
|
||||||
InvalidUpgradeHeader,
|
|
||||||
/// Invalid connection header
|
|
||||||
#[display(fmt = "Invalid connection header")]
|
|
||||||
InvalidConnectionHeader(HeaderValue),
|
|
||||||
/// Missing CONNECTION header
|
|
||||||
#[display(fmt = "Missing CONNECTION header")]
|
|
||||||
MissingConnectionHeader,
|
|
||||||
/// Missing SEC-WEBSOCKET-ACCEPT header
|
|
||||||
#[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")]
|
|
||||||
MissingWebSocketAcceptHeader,
|
|
||||||
/// Invalid challenge response
|
|
||||||
#[display(fmt = "Invalid challenge response")]
|
|
||||||
InvalidChallengeResponse(String, HeaderValue),
|
|
||||||
/// Protocol error
|
|
||||||
#[display(fmt = "{}", _0)]
|
|
||||||
Protocol(ProtocolError),
|
|
||||||
/// Send request error
|
|
||||||
#[display(fmt = "{}", _0)]
|
|
||||||
SendRequest(SendRequestError),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for WsClientError {}
|
|
||||||
|
|
||||||
impl From<InvalidUrl> for WsClientError {
|
|
||||||
fn from(err: InvalidUrl) -> Self {
|
|
||||||
WsClientError::SendRequest(err.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<HttpError> for WsClientError {
|
|
||||||
fn from(err: HttpError) -> Self {
|
|
||||||
WsClientError::SendRequest(err.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A set of errors that can occur during parsing json payloads
|
/// A set of errors that can occur during parsing json payloads
|
||||||
#[derive(Debug, Display, From)]
|
#[derive(Debug, Display, From)]
|
||||||
|
@ -107,16 +61,12 @@ pub enum ConnectError {
|
||||||
Timeout,
|
Timeout,
|
||||||
|
|
||||||
/// Connector has been disconnected
|
/// Connector has been disconnected
|
||||||
#[display(fmt = "Internal error: connector has been disconnected")]
|
#[display(fmt = "Connector has been disconnected")]
|
||||||
Disconnected,
|
Disconnected(Option<io::Error>),
|
||||||
|
|
||||||
/// Unresolved host name
|
/// Unresolved host name
|
||||||
#[display(fmt = "Connector received `Connect` method with unresolved host")]
|
#[display(fmt = "Connector received `Connect` method with unresolved host")]
|
||||||
Unresolved,
|
Unresolved,
|
||||||
|
|
||||||
/// Connection io error
|
|
||||||
#[display(fmt = "{}", _0)]
|
|
||||||
Io(io::Error),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::error::Error for ConnectError {}
|
impl std::error::Error for ConnectError {}
|
||||||
|
@ -128,7 +78,7 @@ impl From<crate::connect::ConnectError> for ConnectError {
|
||||||
crate::connect::ConnectError::NoRecords => ConnectError::NoRecords,
|
crate::connect::ConnectError::NoRecords => ConnectError::NoRecords,
|
||||||
crate::connect::ConnectError::InvalidInput => panic!(),
|
crate::connect::ConnectError::InvalidInput => panic!(),
|
||||||
crate::connect::ConnectError::Unresolved => ConnectError::Unresolved,
|
crate::connect::ConnectError::Unresolved => ConnectError::Unresolved,
|
||||||
crate::connect::ConnectError::Io(e) => ConnectError::Io(e),
|
crate::connect::ConnectError::Io(e) => ConnectError::Disconnected(Some(e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,7 +82,7 @@ where
|
||||||
);
|
);
|
||||||
result
|
result
|
||||||
} else {
|
} else {
|
||||||
return Err(SendRequestError::from(ConnectError::Disconnected));
|
return Err(SendRequestError::from(ConnectError::Disconnected(None)));
|
||||||
};
|
};
|
||||||
|
|
||||||
match codec.message_type() {
|
match codec.message_type() {
|
||||||
|
@ -98,22 +98,6 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) async fn open_tunnel(
|
|
||||||
io: IoBoxed,
|
|
||||||
head: RequestHeadType,
|
|
||||||
) -> Result<(ResponseHead, IoBoxed, h1::ClientCodec), SendRequestError> {
|
|
||||||
// create Framed and send request
|
|
||||||
let codec = h1::ClientCodec::default();
|
|
||||||
io.send(&codec, (head, BodySize::None).into()).await?;
|
|
||||||
|
|
||||||
// read response
|
|
||||||
if let Some(head) = io.recv(&codec).await? {
|
|
||||||
Ok((head, io, codec))
|
|
||||||
} else {
|
|
||||||
Err(SendRequestError::from(ConnectError::Disconnected))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// send request body to the peer
|
/// send request body to the peer
|
||||||
pub(super) async fn send_body<B>(
|
pub(super) async fn send_body<B>(
|
||||||
mut body: B,
|
mut body: B,
|
||||||
|
|
|
@ -31,7 +31,6 @@ mod request;
|
||||||
mod response;
|
mod response;
|
||||||
mod sender;
|
mod sender;
|
||||||
mod test;
|
mod test;
|
||||||
pub mod ws;
|
|
||||||
|
|
||||||
pub use self::builder::ClientBuilder;
|
pub use self::builder::ClientBuilder;
|
||||||
pub use self::connection::Connection;
|
pub use self::connection::Connection;
|
||||||
|
@ -193,17 +192,4 @@ impl Client {
|
||||||
{
|
{
|
||||||
self.request(Method::OPTIONS, url)
|
self.request(Method::OPTIONS, url)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct WebSockets request.
|
|
||||||
pub fn ws<U>(&self, url: U) -> ws::WsRequest
|
|
||||||
where
|
|
||||||
Uri: TryFrom<U>,
|
|
||||||
<Uri as TryFrom<U>>::Error: Into<HttpError>,
|
|
||||||
{
|
|
||||||
let mut req = ws::WsRequest::new(url, self.0.clone());
|
|
||||||
for (key, value) in self.0.headers.iter() {
|
|
||||||
req.head.headers.insert(key.clone(), value.clone());
|
|
||||||
}
|
|
||||||
req
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,7 +136,7 @@ where
|
||||||
OpenConnection::spawn(key, tx, inner, connector.call(req));
|
OpenConnection::spawn(key, tx, inner, connector.call(req));
|
||||||
|
|
||||||
match rx.await {
|
match rx.await {
|
||||||
Err(_) => Err(ConnectError::Disconnected),
|
Err(_) => Err(ConnectError::Disconnected(None)),
|
||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -148,7 +148,7 @@ where
|
||||||
);
|
);
|
||||||
let rx = inner.borrow_mut().wait_for(req);
|
let rx = inner.borrow_mut().wait_for(req);
|
||||||
match rx.await {
|
match rx.await {
|
||||||
Err(_) => Err(ConnectError::Disconnected),
|
Err(_) => Err(ConnectError::Disconnected(None)),
|
||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -536,7 +536,7 @@ mod tests {
|
||||||
fn_service(move |req| {
|
fn_service(move |req| {
|
||||||
let (client, server) = Io::create();
|
let (client, server) = Io::create();
|
||||||
store2.borrow_mut().push((req, server));
|
store2.borrow_mut().push((req, server));
|
||||||
Box::pin(async move { Ok(nio::Io::new(client).into_boxed()) })
|
Box::pin(async move { Ok(nio::Io::new(client).seal()) })
|
||||||
}),
|
}),
|
||||||
Duration::from_secs(10),
|
Duration::from_secs(10),
|
||||||
Duration::from_secs(10),
|
Duration::from_secs(10),
|
||||||
|
|
|
@ -63,6 +63,10 @@ impl ClientResponse {
|
||||||
ClientResponse { head, payload }
|
ClientResponse { head, payload }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn with_empty_payload(head: ResponseHead) -> Self {
|
||||||
|
ClientResponse::new(head, Payload::None)
|
||||||
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub(crate) fn head(&self) -> &ResponseHead {
|
pub(crate) fn head(&self) -> &ResponseHead {
|
||||||
&self.head
|
&self.head
|
||||||
|
|
|
@ -1,597 +0,0 @@
|
||||||
//! Websockets client
|
|
||||||
use std::{convert::TryFrom, fmt, net::SocketAddr, rc::Rc, str};
|
|
||||||
|
|
||||||
#[cfg(feature = "cookie")]
|
|
||||||
use coo_kie::{Cookie, CookieJar};
|
|
||||||
use nanorand::{Rng, WyRand};
|
|
||||||
|
|
||||||
use crate::http::error::HttpError;
|
|
||||||
use crate::http::header::{self, HeaderName, HeaderValue, AUTHORIZATION};
|
|
||||||
use crate::http::{ConnectionType, Payload, RequestHead, StatusCode, Uri};
|
|
||||||
use crate::io::{DispatchItem, Dispatcher, IoBoxed};
|
|
||||||
use crate::service::{apply_fn, into_service, IntoService, Service};
|
|
||||||
use crate::util::{sink, Either, Ready};
|
|
||||||
use crate::{channel::mpsc, rt, time::timeout, ws};
|
|
||||||
|
|
||||||
pub use crate::ws::{CloseCode, CloseReason, Frame, Message};
|
|
||||||
|
|
||||||
use super::error::{InvalidUrl, SendRequestError, WsClientError};
|
|
||||||
use super::response::ClientResponse;
|
|
||||||
use super::ClientConfig;
|
|
||||||
|
|
||||||
/// `WebSocket` connection
|
|
||||||
pub struct WsRequest {
|
|
||||||
pub(crate) head: RequestHead,
|
|
||||||
err: Option<HttpError>,
|
|
||||||
origin: Option<HeaderValue>,
|
|
||||||
protocols: Option<String>,
|
|
||||||
addr: Option<SocketAddr>,
|
|
||||||
max_size: usize,
|
|
||||||
server_mode: bool,
|
|
||||||
#[cfg(feature = "cookie")]
|
|
||||||
cookies: Option<CookieJar>,
|
|
||||||
config: Rc<ClientConfig>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WsRequest {
|
|
||||||
/// Create new websocket connection
|
|
||||||
pub(super) fn new<U>(uri: U, config: Rc<ClientConfig>) -> Self
|
|
||||||
where
|
|
||||||
Uri: TryFrom<U>,
|
|
||||||
<Uri as TryFrom<U>>::Error: Into<HttpError>,
|
|
||||||
{
|
|
||||||
let (head, err) = match Uri::try_from(uri) {
|
|
||||||
Ok(uri) => (
|
|
||||||
RequestHead {
|
|
||||||
uri,
|
|
||||||
..Default::default()
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
Err(e) => (Default::default(), Some(e.into())),
|
|
||||||
};
|
|
||||||
|
|
||||||
WsRequest {
|
|
||||||
head,
|
|
||||||
err,
|
|
||||||
config,
|
|
||||||
addr: None,
|
|
||||||
origin: None,
|
|
||||||
protocols: None,
|
|
||||||
max_size: 65_536,
|
|
||||||
server_mode: false,
|
|
||||||
#[cfg(feature = "cookie")]
|
|
||||||
cookies: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set socket address of the server.
|
|
||||||
///
|
|
||||||
/// This address is used for connection. If address is not
|
|
||||||
/// provided url's host name get resolved.
|
|
||||||
pub fn address(mut self, addr: SocketAddr) -> Self {
|
|
||||||
self.addr = Some(addr);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set supported websocket protocols
|
|
||||||
pub fn protocols<U, V>(mut self, protos: U) -> Self
|
|
||||||
where
|
|
||||||
U: IntoIterator<Item = V>,
|
|
||||||
V: AsRef<str>,
|
|
||||||
{
|
|
||||||
let mut protos = protos
|
|
||||||
.into_iter()
|
|
||||||
.fold(String::new(), |acc, s| acc + s.as_ref() + ",");
|
|
||||||
protos.pop();
|
|
||||||
self.protocols = Some(protos);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "cookie")]
|
|
||||||
/// Set a cookie
|
|
||||||
pub fn cookie(mut self, cookie: Cookie<'_>) -> Self {
|
|
||||||
if self.cookies.is_none() {
|
|
||||||
let mut jar = CookieJar::new();
|
|
||||||
jar.add(cookie.into_owned());
|
|
||||||
self.cookies = Some(jar)
|
|
||||||
} else {
|
|
||||||
self.cookies.as_mut().unwrap().add(cookie.into_owned());
|
|
||||||
}
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set request Origin
|
|
||||||
pub fn origin<V, E>(mut self, origin: V) -> Self
|
|
||||||
where
|
|
||||||
HeaderValue: TryFrom<V, Error = E>,
|
|
||||||
HttpError: From<E>,
|
|
||||||
{
|
|
||||||
match HeaderValue::try_from(origin) {
|
|
||||||
Ok(value) => self.origin = Some(value),
|
|
||||||
Err(e) => self.err = Some(e.into()),
|
|
||||||
}
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set max frame size
|
|
||||||
///
|
|
||||||
/// By default max size is set to 64kb
|
|
||||||
pub fn max_frame_size(mut self, size: usize) -> Self {
|
|
||||||
self.max_size = size;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Disable payload masking. By default ws client masks frame payload.
|
|
||||||
pub fn server_mode(mut self) -> Self {
|
|
||||||
self.server_mode = true;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Append a header.
|
|
||||||
///
|
|
||||||
/// Header gets appended to existing header.
|
|
||||||
/// To override header use `set_header()` method.
|
|
||||||
pub fn header<K, V>(mut self, key: K, value: V) -> Self
|
|
||||||
where
|
|
||||||
HeaderName: TryFrom<K>,
|
|
||||||
HeaderValue: TryFrom<V>,
|
|
||||||
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
|
|
||||||
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
|
|
||||||
{
|
|
||||||
match HeaderName::try_from(key) {
|
|
||||||
Ok(key) => match HeaderValue::try_from(value) {
|
|
||||||
Ok(value) => {
|
|
||||||
self.head.headers.append(key, value);
|
|
||||||
}
|
|
||||||
Err(e) => self.err = Some(e.into()),
|
|
||||||
},
|
|
||||||
Err(e) => self.err = Some(e.into()),
|
|
||||||
}
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Insert a header, replaces existing header.
|
|
||||||
pub fn set_header<K, V>(mut self, key: K, value: V) -> Self
|
|
||||||
where
|
|
||||||
HeaderName: TryFrom<K>,
|
|
||||||
HeaderValue: TryFrom<V>,
|
|
||||||
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
|
|
||||||
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
|
|
||||||
{
|
|
||||||
match HeaderName::try_from(key) {
|
|
||||||
Ok(key) => match HeaderValue::try_from(value) {
|
|
||||||
Ok(value) => {
|
|
||||||
self.head.headers.insert(key, value);
|
|
||||||
}
|
|
||||||
Err(e) => self.err = Some(e.into()),
|
|
||||||
},
|
|
||||||
Err(e) => self.err = Some(e.into()),
|
|
||||||
}
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Insert a header only if it is not yet set.
|
|
||||||
pub fn set_header_if_none<K, V>(mut self, key: K, value: V) -> Self
|
|
||||||
where
|
|
||||||
HeaderName: TryFrom<K>,
|
|
||||||
HeaderValue: TryFrom<V>,
|
|
||||||
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
|
|
||||||
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
|
|
||||||
{
|
|
||||||
match HeaderName::try_from(key) {
|
|
||||||
Ok(key) => {
|
|
||||||
if !self.head.headers.contains_key(&key) {
|
|
||||||
match HeaderValue::try_from(value) {
|
|
||||||
Ok(value) => {
|
|
||||||
self.head.headers.insert(key, value);
|
|
||||||
}
|
|
||||||
Err(e) => self.err = Some(e.into()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => self.err = Some(e.into()),
|
|
||||||
}
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set HTTP basic authorization header
|
|
||||||
pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
|
|
||||||
where
|
|
||||||
U: fmt::Display,
|
|
||||||
{
|
|
||||||
let auth = match password {
|
|
||||||
Some(password) => format!("{}:{}", username, password),
|
|
||||||
None => format!("{}:", username),
|
|
||||||
};
|
|
||||||
self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set HTTP bearer authentication header
|
|
||||||
pub fn bearer_auth<T>(self, token: T) -> Self
|
|
||||||
where
|
|
||||||
T: fmt::Display,
|
|
||||||
{
|
|
||||||
self.header(AUTHORIZATION, format!("Bearer {}", token))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Complete request construction and connect to a websockets server.
|
|
||||||
pub async fn connect(mut self) -> Result<WsConnection, WsClientError> {
|
|
||||||
if let Some(e) = self.err.take() {
|
|
||||||
return Err(WsClientError::from(e));
|
|
||||||
}
|
|
||||||
|
|
||||||
// validate uri
|
|
||||||
let uri = &self.head.uri;
|
|
||||||
if uri.host().is_none() {
|
|
||||||
return Err(InvalidUrl::MissingHost.into());
|
|
||||||
} else if uri.scheme().is_none() {
|
|
||||||
return Err(InvalidUrl::MissingScheme.into());
|
|
||||||
} else if let Some(scheme) = uri.scheme() {
|
|
||||||
match scheme.as_str() {
|
|
||||||
"http" | "ws" | "https" | "wss" => (),
|
|
||||||
_ => return Err(InvalidUrl::UnknownScheme.into()),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(InvalidUrl::UnknownScheme.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
if !self.head.headers.contains_key(header::HOST) {
|
|
||||||
self.head.headers.insert(
|
|
||||||
header::HOST,
|
|
||||||
HeaderValue::from_str(uri.host().unwrap()).unwrap(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "cookie")]
|
|
||||||
{
|
|
||||||
use percent_encoding::percent_encode;
|
|
||||||
use std::fmt::Write as FmtWrite;
|
|
||||||
|
|
||||||
// set cookies
|
|
||||||
if let Some(ref mut jar) = self.cookies {
|
|
||||||
let mut cookie = String::new();
|
|
||||||
for c in jar.delta() {
|
|
||||||
let name = percent_encode(
|
|
||||||
c.name().as_bytes(),
|
|
||||||
crate::http::helpers::USERINFO,
|
|
||||||
);
|
|
||||||
let value = percent_encode(
|
|
||||||
c.value().as_bytes(),
|
|
||||||
crate::http::helpers::USERINFO,
|
|
||||||
);
|
|
||||||
let _ = write!(&mut cookie, "; {}={}", name, value);
|
|
||||||
}
|
|
||||||
self.head.headers.insert(
|
|
||||||
header::COOKIE,
|
|
||||||
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// origin
|
|
||||||
if let Some(origin) = self.origin.take() {
|
|
||||||
self.head.headers.insert(header::ORIGIN, origin);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.head.set_connection_type(ConnectionType::Upgrade);
|
|
||||||
self.head
|
|
||||||
.headers
|
|
||||||
.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
|
|
||||||
self.head.headers.insert(
|
|
||||||
header::SEC_WEBSOCKET_VERSION,
|
|
||||||
HeaderValue::from_static("13"),
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(protocols) = self.protocols.take() {
|
|
||||||
self.head.headers.insert(
|
|
||||||
header::SEC_WEBSOCKET_PROTOCOL,
|
|
||||||
HeaderValue::try_from(protocols.as_str()).unwrap(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a random key for the `Sec-WebSocket-Key` header.
|
|
||||||
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
|
||||||
// when decoded, is 16 bytes in length (RFC 6455)
|
|
||||||
let mut sec_key: [u8; 16] = [0; 16];
|
|
||||||
WyRand::new().fill(&mut sec_key);
|
|
||||||
let key = base64::encode(&sec_key);
|
|
||||||
|
|
||||||
self.head.headers.insert(
|
|
||||||
header::SEC_WEBSOCKET_KEY,
|
|
||||||
HeaderValue::try_from(key.as_str()).unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let head = self.head;
|
|
||||||
let max_size = self.max_size;
|
|
||||||
let server_mode = self.server_mode;
|
|
||||||
|
|
||||||
let fut = self.config.connector.open_tunnel(head.into(), self.addr);
|
|
||||||
|
|
||||||
// set request timeout
|
|
||||||
let (head, io, _) = if self.config.timeout.non_zero() {
|
|
||||||
timeout(self.config.timeout, fut)
|
|
||||||
.await
|
|
||||||
.map_err(|_| SendRequestError::Timeout)
|
|
||||||
.and_then(|res| res)?
|
|
||||||
} else {
|
|
||||||
fut.await?
|
|
||||||
};
|
|
||||||
|
|
||||||
// verify response
|
|
||||||
if head.status != StatusCode::SWITCHING_PROTOCOLS {
|
|
||||||
return Err(WsClientError::InvalidResponseStatus(head.status));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for "UPGRADE" to websocket header
|
|
||||||
let has_hdr = if let Some(hdr) = head.headers.get(&header::UPGRADE) {
|
|
||||||
if let Ok(s) = hdr.to_str() {
|
|
||||||
s.to_ascii_lowercase().contains("websocket")
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
if !has_hdr {
|
|
||||||
log::trace!("Invalid upgrade header");
|
|
||||||
return Err(WsClientError::InvalidUpgradeHeader);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for "CONNECTION" header
|
|
||||||
if let Some(conn) = head.headers.get(&header::CONNECTION) {
|
|
||||||
if let Ok(s) = conn.to_str() {
|
|
||||||
if !s.to_ascii_lowercase().contains("upgrade") {
|
|
||||||
log::trace!("Invalid connection header: {}", s);
|
|
||||||
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log::trace!("Invalid connection header: {:?}", conn);
|
|
||||||
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log::trace!("Missing connection header");
|
|
||||||
return Err(WsClientError::MissingConnectionHeader);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(hdr_key) = head.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
|
|
||||||
let encoded = ws::hash_key(key.as_ref());
|
|
||||||
if hdr_key.as_bytes() != encoded.as_bytes() {
|
|
||||||
log::trace!(
|
|
||||||
"Invalid challenge response: expected: {} received: {:?}",
|
|
||||||
encoded,
|
|
||||||
key
|
|
||||||
);
|
|
||||||
return Err(WsClientError::InvalidChallengeResponse(
|
|
||||||
encoded,
|
|
||||||
hdr_key.clone(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
|
|
||||||
return Err(WsClientError::MissingWebSocketAcceptHeader);
|
|
||||||
};
|
|
||||||
|
|
||||||
// response and ws io
|
|
||||||
Ok(WsConnection::new(
|
|
||||||
ClientResponse::new(head, Payload::None),
|
|
||||||
io,
|
|
||||||
if server_mode {
|
|
||||||
ws::Codec::new().max_size(max_size)
|
|
||||||
} else {
|
|
||||||
ws::Codec::new().max_size(max_size).client_mode()
|
|
||||||
},
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Debug for WsRequest {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
writeln!(
|
|
||||||
f,
|
|
||||||
"\nWebsocketsRequest {}:{}",
|
|
||||||
self.head.method, self.head.uri
|
|
||||||
)?;
|
|
||||||
writeln!(f, " headers:")?;
|
|
||||||
for (key, val) in self.head.headers.iter() {
|
|
||||||
writeln!(f, " {:?}: {:?}", key, val)?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct WsConnection {
|
|
||||||
io: IoBoxed,
|
|
||||||
codec: ws::Codec,
|
|
||||||
res: ClientResponse,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WsConnection {
|
|
||||||
fn new(res: ClientResponse, io: IoBoxed, codec: ws::Codec) -> Self {
|
|
||||||
Self { io, codec, res }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get ws sink
|
|
||||||
pub fn sink(&self) -> ws::WsSink {
|
|
||||||
ws::WsSink::new(self.io.get_ref(), self.codec.clone())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get reference to response
|
|
||||||
pub fn response(&self) -> &ClientResponse {
|
|
||||||
&self.res
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: fix close frame handling
|
|
||||||
/// Start client websockets with `SinkService` and `mpsc::Receiver<Frame>`
|
|
||||||
pub fn start_default(self) -> mpsc::Receiver<Result<ws::Frame, ws::WsError<()>>> {
|
|
||||||
let (tx, rx): (_, mpsc::Receiver<Result<ws::Frame, ws::WsError<()>>>) =
|
|
||||||
mpsc::channel();
|
|
||||||
|
|
||||||
rt::spawn(async move {
|
|
||||||
let io = self.io.get_ref();
|
|
||||||
let srv = sink::SinkService::new(tx.clone()).map(|_| None);
|
|
||||||
|
|
||||||
if let Err(err) = self
|
|
||||||
.start(into_service(move |item| {
|
|
||||||
let io = io.clone();
|
|
||||||
let close = matches!(item, ws::Frame::Close(_));
|
|
||||||
let fut = srv.call(Ok::<_, ws::WsError<()>>(item));
|
|
||||||
async move {
|
|
||||||
let result = fut.await.map_err(|_| ());
|
|
||||||
if close {
|
|
||||||
io.close();
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
let _ = tx.send(Err(err));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
rx
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Start client websockets service.
|
|
||||||
pub async fn start<T, U>(self, service: U) -> Result<(), ws::WsError<T::Error>>
|
|
||||||
where
|
|
||||||
T: Service<Request = ws::Frame, Response = Option<ws::Message>> + 'static,
|
|
||||||
U: IntoService<T>,
|
|
||||||
{
|
|
||||||
let service = apply_fn(
|
|
||||||
service.into_service().map_err(ws::WsError::Service),
|
|
||||||
|req, srv| match req {
|
|
||||||
DispatchItem::Item(item) => Either::Left(srv.call(item)),
|
|
||||||
DispatchItem::WBackPressureEnabled
|
|
||||||
| DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)),
|
|
||||||
DispatchItem::KeepAliveTimeout => {
|
|
||||||
Either::Right(Ready::Err(ws::WsError::KeepAlive))
|
|
||||||
}
|
|
||||||
DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => {
|
|
||||||
Either::Right(Ready::Err(ws::WsError::Protocol(e)))
|
|
||||||
}
|
|
||||||
DispatchItem::Disconnect(Some(e)) => {
|
|
||||||
Either::Right(Ready::Err(ws::WsError::Io(e)))
|
|
||||||
}
|
|
||||||
DispatchItem::Disconnect(None) => {
|
|
||||||
Either::Right(Ready::Err(ws::WsError::Disconnected))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
Dispatcher::new(self.io, self.codec, service, Default::default()).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Consumes the `WsConnection`, returning it'as underlying I/O framed object
|
|
||||||
/// and response.
|
|
||||||
pub fn into_inner(self) -> (ClientResponse, IoBoxed, ws::Codec) {
|
|
||||||
(self.res, self.io, self.codec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::http::client::Client;
|
|
||||||
|
|
||||||
#[crate::rt_test]
|
|
||||||
async fn test_debug() {
|
|
||||||
let request = Client::new().ws("/").header("x-test", "111");
|
|
||||||
let repr = format!("{:?}", request);
|
|
||||||
assert!(repr.contains("WebsocketsRequest"));
|
|
||||||
assert!(repr.contains("x-test"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[crate::rt_test]
|
|
||||||
async fn test_header_override() {
|
|
||||||
let req = Client::build()
|
|
||||||
.header(header::CONTENT_TYPE, "111")
|
|
||||||
.finish()
|
|
||||||
.ws("/")
|
|
||||||
.set_header(header::CONTENT_TYPE, "222");
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
req.head
|
|
||||||
.headers
|
|
||||||
.get(header::CONTENT_TYPE)
|
|
||||||
.unwrap()
|
|
||||||
.to_str()
|
|
||||||
.unwrap(),
|
|
||||||
"222"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[crate::rt_test]
|
|
||||||
async fn basic_auth() {
|
|
||||||
let req = Client::new()
|
|
||||||
.ws("/")
|
|
||||||
.basic_auth("username", Some("password"));
|
|
||||||
assert_eq!(
|
|
||||||
req.head
|
|
||||||
.headers
|
|
||||||
.get(header::AUTHORIZATION)
|
|
||||||
.unwrap()
|
|
||||||
.to_str()
|
|
||||||
.unwrap(),
|
|
||||||
"Basic dXNlcm5hbWU6cGFzc3dvcmQ="
|
|
||||||
);
|
|
||||||
|
|
||||||
let req = Client::new().ws("/").basic_auth("username", None);
|
|
||||||
assert_eq!(
|
|
||||||
req.head
|
|
||||||
.headers
|
|
||||||
.get(header::AUTHORIZATION)
|
|
||||||
.unwrap()
|
|
||||||
.to_str()
|
|
||||||
.unwrap(),
|
|
||||||
"Basic dXNlcm5hbWU6"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[crate::rt_test]
|
|
||||||
async fn bearer_auth() {
|
|
||||||
let req = Client::new().ws("/").bearer_auth("someS3cr3tAutht0k3n");
|
|
||||||
assert_eq!(
|
|
||||||
req.head
|
|
||||||
.headers
|
|
||||||
.get(header::AUTHORIZATION)
|
|
||||||
.unwrap()
|
|
||||||
.to_str()
|
|
||||||
.unwrap(),
|
|
||||||
"Bearer someS3cr3tAutht0k3n"
|
|
||||||
);
|
|
||||||
let _ = req.connect();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "cookie")]
|
|
||||||
#[crate::rt_test]
|
|
||||||
async fn basics() {
|
|
||||||
let req = Client::new()
|
|
||||||
.ws("http://localhost/")
|
|
||||||
.origin("test-origin")
|
|
||||||
.max_frame_size(100)
|
|
||||||
.server_mode()
|
|
||||||
.protocols(&["v1", "v2"])
|
|
||||||
.set_header_if_none(header::CONTENT_TYPE, "json")
|
|
||||||
.set_header_if_none(header::CONTENT_TYPE, "text")
|
|
||||||
.cookie(Cookie::build("cookie1", "value1").finish());
|
|
||||||
assert_eq!(
|
|
||||||
req.origin.as_ref().unwrap().to_str().unwrap(),
|
|
||||||
"test-origin"
|
|
||||||
);
|
|
||||||
assert_eq!(req.max_size, 100);
|
|
||||||
assert_eq!(req.server_mode, true);
|
|
||||||
assert_eq!(req.protocols, Some("v1,v2".to_string()));
|
|
||||||
assert_eq!(
|
|
||||||
req.head.headers.get(header::CONTENT_TYPE).unwrap(),
|
|
||||||
header::HeaderValue::from_static("json")
|
|
||||||
);
|
|
||||||
|
|
||||||
let _ = req.connect().await;
|
|
||||||
|
|
||||||
assert!(Client::new().ws("/").connect().await.is_err());
|
|
||||||
assert!(Client::new().ws("http:///test").connect().await.is_err());
|
|
||||||
assert!(Client::new().ws("hmm://test.com/").connect().await.is_err());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -86,10 +86,6 @@ pub enum ParseError {
|
||||||
Timeout,
|
Timeout,
|
||||||
/// An `InvalidInput` occurred while trying to parse incoming stream.
|
/// An `InvalidInput` occurred while trying to parse incoming stream.
|
||||||
InvalidInput(&'static str),
|
InvalidInput(&'static str),
|
||||||
/// An `io::Error` that occurred while trying to read or write to a network
|
|
||||||
/// stream.
|
|
||||||
#[display(fmt = "IO error: {}", _0)]
|
|
||||||
Io(io::Error),
|
|
||||||
/// Parsing a field as string failed
|
/// Parsing a field as string failed
|
||||||
#[display(fmt = "UTF8 error: {}", _0)]
|
#[display(fmt = "UTF8 error: {}", _0)]
|
||||||
Utf8(Utf8Error),
|
Utf8(Utf8Error),
|
||||||
|
@ -321,21 +317,8 @@ mod tests {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! from_and_cause {
|
|
||||||
($from:expr => $error:pat) => {
|
|
||||||
match ParseError::from($from) {
|
|
||||||
e @ $error => {
|
|
||||||
let desc = format!("{}", e);
|
|
||||||
assert_eq!(desc, format!("IO error: {}", $from));
|
|
||||||
}
|
|
||||||
_ => unreachable!("{:?}", $from),
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_from() {
|
fn test_from() {
|
||||||
from_and_cause!(io::Error::new(io::ErrorKind::Other, "other") => ParseError::Io(..));
|
|
||||||
from!(httparse::Error::HeaderName => ParseError::Header);
|
from!(httparse::Error::HeaderName => ParseError::Header);
|
||||||
from!(httparse::Error::HeaderName => ParseError::Header);
|
from!(httparse::Error::HeaderName => ParseError::Header);
|
||||||
from!(httparse::Error::HeaderValue => ParseError::Header);
|
from!(httparse::Error::HeaderValue => ParseError::Header);
|
||||||
|
|
|
@ -674,7 +674,6 @@ impl ChunkedState {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::http::error::ParseError;
|
|
||||||
use crate::http::header::{HeaderName, SET_COOKIE};
|
use crate::http::header::{HeaderName, SET_COOKIE};
|
||||||
use crate::http::{HttpMessage, Method, Version};
|
use crate::http::{HttpMessage, Method, Version};
|
||||||
use crate::util::{Bytes, BytesMut};
|
use crate::util::{Bytes, BytesMut};
|
||||||
|
@ -723,11 +722,7 @@ mod tests {
|
||||||
macro_rules! expect_parse_err {
|
macro_rules! expect_parse_err {
|
||||||
($e:expr) => {{
|
($e:expr) => {{
|
||||||
match MessageDecoder::<Request>::default().decode($e) {
|
match MessageDecoder::<Request>::default().decode($e) {
|
||||||
Err(err) => {
|
Err(_) => (),
|
||||||
if let ParseError::Io(_) = err {
|
|
||||||
unreachable!("Parse error expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => unreachable!("Error expected"),
|
_ => unreachable!("Error expected"),
|
||||||
}
|
}
|
||||||
}};
|
}};
|
||||||
|
|
|
@ -45,7 +45,7 @@ pub struct RequestHead {
|
||||||
pub headers: HeaderMap,
|
pub headers: HeaderMap,
|
||||||
pub extensions: RefCell<Extensions>,
|
pub extensions: RefCell<Extensions>,
|
||||||
pub io: Option<IoRef>,
|
pub io: Option<IoRef>,
|
||||||
pub(super) flags: Flags,
|
pub(crate) flags: Flags,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RequestHead {
|
impl Default for RequestHead {
|
||||||
|
|
|
@ -4,13 +4,11 @@ use std::{convert::TryFrom, io, net, str::FromStr, sync::mpsc, thread};
|
||||||
#[cfg(feature = "cookie")]
|
#[cfg(feature = "cookie")]
|
||||||
use coo_kie::{Cookie, CookieJar};
|
use coo_kie::{Cookie, CookieJar};
|
||||||
|
|
||||||
use crate::{io::Io, rt::System, server::Server, service::ServiceFactory};
|
use crate::ws::{error::WsClientError, WsClient, WsConnection};
|
||||||
|
use crate::{io::Filter, io::Io, rt::System, server::Server, service::ServiceFactory};
|
||||||
use crate::{time::Millis, time::Seconds, util::Bytes};
|
use crate::{time::Millis, time::Seconds, util::Bytes};
|
||||||
|
|
||||||
use super::client::error::WsClientError;
|
use super::client::{Client, ClientRequest, ClientResponse, Connector};
|
||||||
use super::client::{
|
|
||||||
ws::WsConnection, Client, ClientRequest, ClientResponse, Connector,
|
|
||||||
};
|
|
||||||
use super::error::{HttpError, PayloadError};
|
use super::error::{HttpError, PayloadError};
|
||||||
use super::header::{HeaderMap, HeaderName, HeaderValue};
|
use super::header::{HeaderMap, HeaderName, HeaderValue};
|
||||||
use super::payload::Payload;
|
use super::payload::Payload;
|
||||||
|
@ -322,14 +320,54 @@ impl TestServer {
|
||||||
response.body().limit(10_485_760).await
|
response.body().limit(10_485_760).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to websocket server at a given path
|
/// Connect to a websocket server
|
||||||
pub async fn ws_at(&mut self, path: &str) -> Result<WsConnection, WsClientError> {
|
pub async fn ws(&mut self) -> Result<WsConnection<impl Filter>, WsClientError> {
|
||||||
self.client.ws(self.url(path)).connect().await
|
self.ws_at("/").await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Connect to websocket server at a given path
|
||||||
|
pub async fn ws_at(
|
||||||
|
&mut self,
|
||||||
|
path: &str,
|
||||||
|
) -> Result<WsConnection<impl Filter>, WsClientError> {
|
||||||
|
WsClient::build(self.url(path))
|
||||||
|
.address(self.addr)
|
||||||
|
.timeout(Seconds(30))
|
||||||
|
.finish()
|
||||||
|
.unwrap()
|
||||||
|
.connect()
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "openssl")]
|
||||||
/// Connect to a websocket server
|
/// Connect to a websocket server
|
||||||
pub async fn ws(&mut self) -> Result<WsConnection, WsClientError> {
|
pub async fn wss(&mut self) -> Result<WsConnection<impl Filter>, WsClientError> {
|
||||||
self.ws_at("/").await
|
self.wss_at("/").await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "openssl")]
|
||||||
|
/// Connect to secure websocket server at a given path
|
||||||
|
pub async fn wss_at(
|
||||||
|
&mut self,
|
||||||
|
path: &str,
|
||||||
|
) -> Result<WsConnection<impl Filter>, WsClientError> {
|
||||||
|
use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
|
||||||
|
|
||||||
|
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||||
|
builder.set_verify(SslVerifyMode::NONE);
|
||||||
|
let _ = builder
|
||||||
|
.set_alpn_protos(b"\x08http/1.1")
|
||||||
|
.map_err(|e| log::error!("Cannot set alpn protocol: {:?}", e));
|
||||||
|
|
||||||
|
WsClient::build(self.url(path))
|
||||||
|
.address(self.addr)
|
||||||
|
.timeout(Seconds(30))
|
||||||
|
.openssl(builder.build())
|
||||||
|
.take()
|
||||||
|
.finish()
|
||||||
|
.unwrap()
|
||||||
|
.connect()
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stop http server
|
/// Stop http server
|
||||||
|
|
|
@ -54,7 +54,7 @@ impl ResponseError for HandshakeError {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseError for crate::ws::ProtocolError {}
|
impl ResponseError for crate::ws::error::ProtocolError {}
|
||||||
|
|
||||||
/// Verify `WebSocket` handshake request and create handshake reponse.
|
/// Verify `WebSocket` handshake request and create handshake reponse.
|
||||||
// /// `protocols` is a sequence of known protocols. On successful handshake,
|
// /// `protocols` is a sequence of known protocols. On successful handshake,
|
||||||
|
|
|
@ -10,8 +10,7 @@ use serde::de::DeserializeOwned;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
use crate::http::body::MessageBody;
|
use crate::http::body::MessageBody;
|
||||||
use crate::http::client::error::WsClientError;
|
use crate::http::client::{Client, ClientRequest, ClientResponse, Connector};
|
||||||
use crate::http::client::{ws, Client, ClientRequest, ClientResponse, Connector};
|
|
||||||
use crate::http::error::{HttpError, PayloadError, ResponseError};
|
use crate::http::error::{HttpError, PayloadError, ResponseError};
|
||||||
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
|
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
|
||||||
use crate::http::test::TestRequest as HttpTestRequest;
|
use crate::http::test::TestRequest as HttpTestRequest;
|
||||||
|
@ -22,7 +21,8 @@ use crate::service::{
|
||||||
};
|
};
|
||||||
use crate::time::{sleep, Millis, Seconds};
|
use crate::time::{sleep, Millis, Seconds};
|
||||||
use crate::util::{next, Bytes, BytesMut, Extensions, Ready};
|
use crate::util::{next, Bytes, BytesMut, Extensions, Ready};
|
||||||
use crate::{rt::System, server::Server, Stream};
|
use crate::ws::{error::WsClientError, WsClient, WsConnection};
|
||||||
|
use crate::{io::Sealed, rt::System, server::Server, Stream};
|
||||||
|
|
||||||
use crate::web::config::AppConfig;
|
use crate::web::config::AppConfig;
|
||||||
use crate::web::error::{DefaultError, ErrorRenderer};
|
use crate::web::error::{DefaultError, ErrorRenderer};
|
||||||
|
@ -919,12 +919,50 @@ impl TestServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to websocket server at a given path
|
/// Connect to websocket server at a given path
|
||||||
pub async fn ws_at(&self, path: &str) -> Result<ws::WsConnection, WsClientError> {
|
pub async fn ws_at(
|
||||||
self.client.ws(self.url(path)).connect().await
|
&self,
|
||||||
|
path: &str,
|
||||||
|
) -> Result<WsConnection<Sealed>, WsClientError> {
|
||||||
|
if self.ssl {
|
||||||
|
#[cfg(feature = "openssl")]
|
||||||
|
{
|
||||||
|
use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
|
||||||
|
|
||||||
|
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||||
|
builder.set_verify(SslVerifyMode::NONE);
|
||||||
|
let _ = builder
|
||||||
|
.set_alpn_protos(b"\x08http/1.1")
|
||||||
|
.map_err(|e| log::error!("Cannot set alpn protocol: {:?}", e));
|
||||||
|
|
||||||
|
WsClient::build(self.url(path))
|
||||||
|
.address(self.addr)
|
||||||
|
.timeout(Seconds(30))
|
||||||
|
.openssl(builder.build())
|
||||||
|
.take()
|
||||||
|
.finish()
|
||||||
|
.unwrap()
|
||||||
|
.connect()
|
||||||
|
.await
|
||||||
|
.map(|ws| ws.seal())
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "openssl"))]
|
||||||
|
{
|
||||||
|
panic!("openssl feature is required")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
WsClient::build(self.url(path))
|
||||||
|
.address(self.addr)
|
||||||
|
.timeout(Seconds(30))
|
||||||
|
.finish()
|
||||||
|
.unwrap()
|
||||||
|
.connect()
|
||||||
|
.await
|
||||||
|
.map(|ws| ws.seal())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Connect to a websocket server
|
/// Connect to a websocket server
|
||||||
pub async fn ws(&self) -> Result<ws::WsConnection, WsClientError> {
|
pub async fn ws(&self) -> Result<WsConnection<Sealed>, WsClientError> {
|
||||||
self.ws_at("/").await
|
self.ws_at("/").await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
936
ntex/src/ws/client.rs
Normal file
936
ntex/src/ws/client.rs
Normal file
|
@ -0,0 +1,936 @@
|
||||||
|
//! Websockets client
|
||||||
|
use std::future::Future;
|
||||||
|
use std::{cell::RefCell, convert::TryFrom, fmt, marker, net, rc::Rc, str};
|
||||||
|
|
||||||
|
#[cfg(feature = "openssl")]
|
||||||
|
use crate::connect::openssl;
|
||||||
|
#[cfg(feature = "rustls")]
|
||||||
|
use crate::connect::rustls;
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
use coo_kie::{Cookie, CookieJar};
|
||||||
|
|
||||||
|
use nanorand::{Rng, WyRand};
|
||||||
|
|
||||||
|
use crate::connect::{Connect, ConnectError, Connector};
|
||||||
|
use crate::http::header::{self, HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
|
||||||
|
use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1};
|
||||||
|
use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri};
|
||||||
|
use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Sealed};
|
||||||
|
use crate::service::{apply_fn, into_service, IntoService, Service};
|
||||||
|
use crate::util::{sink, Either, Ready};
|
||||||
|
use crate::{channel::mpsc, rt, time::timeout, time::Millis, ws};
|
||||||
|
|
||||||
|
use super::error::{WsClientBuilderError, WsClientError, WsError};
|
||||||
|
|
||||||
|
/// `WebSocket` client builder
|
||||||
|
pub struct WsClient<F, T> {
|
||||||
|
connector: T,
|
||||||
|
head: Rc<RequestHead>,
|
||||||
|
addr: Option<net::SocketAddr>,
|
||||||
|
max_size: usize,
|
||||||
|
server_mode: bool,
|
||||||
|
timeout: Millis,
|
||||||
|
extra_headers: RefCell<Option<HeaderMap>>,
|
||||||
|
_t: marker::PhantomData<F>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `WebSocket` client builder
|
||||||
|
pub struct WsClientBuilder<F, T> {
|
||||||
|
inner: Option<Inner<F, T>>,
|
||||||
|
err: Option<HttpError>,
|
||||||
|
protocols: Option<String>,
|
||||||
|
origin: Option<HeaderValue>,
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
cookies: Option<CookieJar>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Inner<F, T> {
|
||||||
|
connector: T,
|
||||||
|
pub(crate) head: RequestHead,
|
||||||
|
addr: Option<net::SocketAddr>,
|
||||||
|
max_size: usize,
|
||||||
|
server_mode: bool,
|
||||||
|
timeout: Millis,
|
||||||
|
_t: marker::PhantomData<F>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsClient<Base, ()> {
|
||||||
|
/// Create new websocket client builder
|
||||||
|
pub fn build<U>(
|
||||||
|
uri: U,
|
||||||
|
) -> WsClientBuilder<
|
||||||
|
Base,
|
||||||
|
impl Service<Request = Connect<Uri>, Response = Io, Error = ConnectError>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
Uri: TryFrom<U>,
|
||||||
|
<Uri as TryFrom<U>>::Error: Into<HttpError>,
|
||||||
|
{
|
||||||
|
WsClientBuilder::new(uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create new websocket client builder
|
||||||
|
pub fn with_connector<F, T, U>(uri: U, connector: T) -> WsClientBuilder<F, T>
|
||||||
|
where
|
||||||
|
Uri: TryFrom<U>,
|
||||||
|
<Uri as TryFrom<U>>::Error: Into<HttpError>,
|
||||||
|
F: Filter,
|
||||||
|
T: Service<Request = Connect<Uri>, Response = Io<F>, Error = ConnectError>,
|
||||||
|
{
|
||||||
|
WsClientBuilder::new(uri).connector(connector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, T> WsClient<F, T> {
|
||||||
|
/// Insert a header, replaces existing header.
|
||||||
|
pub fn set_header<K, V>(&self, key: K, value: V) -> Result<(), HttpError>
|
||||||
|
where
|
||||||
|
HeaderName: TryFrom<K>,
|
||||||
|
HeaderValue: TryFrom<V>,
|
||||||
|
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
|
||||||
|
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
|
||||||
|
{
|
||||||
|
let key = HeaderName::try_from(key).map_err(Into::into)?;
|
||||||
|
let value = HeaderValue::try_from(value).map_err(Into::into)?;
|
||||||
|
if let Some(headers) = &mut *self.extra_headers.borrow_mut() {
|
||||||
|
headers.insert(key, value);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(key, value);
|
||||||
|
*self.extra_headers.borrow_mut() = Some(headers);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set HTTP basic authorization header
|
||||||
|
pub fn set_basic_auth<U>(
|
||||||
|
&self,
|
||||||
|
username: U,
|
||||||
|
password: Option<&str>,
|
||||||
|
) -> Result<(), HttpError>
|
||||||
|
where
|
||||||
|
U: fmt::Display,
|
||||||
|
{
|
||||||
|
let auth = match password {
|
||||||
|
Some(password) => format!("{}:{}", username, password),
|
||||||
|
None => format!("{}:", username),
|
||||||
|
};
|
||||||
|
self.set_header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set HTTP bearer authentication header
|
||||||
|
pub fn set_bearer_auth<U>(&self, token: U) -> Result<(), HttpError>
|
||||||
|
where
|
||||||
|
U: fmt::Display,
|
||||||
|
{
|
||||||
|
self.set_header(AUTHORIZATION, format!("Bearer {}", token))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, T> WsClient<F, T>
|
||||||
|
where
|
||||||
|
F: Filter,
|
||||||
|
T: Service<Request = Connect<Uri>, Response = Io<F>, Error = ConnectError>,
|
||||||
|
{
|
||||||
|
/// Complete request construction and connect to a websockets server.
|
||||||
|
pub fn connect(
|
||||||
|
&self,
|
||||||
|
) -> impl Future<Output = Result<WsConnection<F>, WsClientError>> {
|
||||||
|
let head = self.head.clone();
|
||||||
|
let max_size = self.max_size;
|
||||||
|
let server_mode = self.server_mode;
|
||||||
|
let to = self.timeout;
|
||||||
|
let mut headers = self
|
||||||
|
.extra_headers
|
||||||
|
.borrow_mut()
|
||||||
|
.take()
|
||||||
|
.unwrap_or_else(HeaderMap::new);
|
||||||
|
|
||||||
|
// Generate a random key for the `Sec-WebSocket-Key` header.
|
||||||
|
// a base64-encoded (see Section 4 of [RFC4648]) value that,
|
||||||
|
// when decoded, is 16 bytes in length (RFC 6455)
|
||||||
|
let mut sec_key: [u8; 16] = [0; 16];
|
||||||
|
WyRand::new().fill(&mut sec_key);
|
||||||
|
let key = base64::encode(&sec_key);
|
||||||
|
|
||||||
|
headers.insert(
|
||||||
|
header::SEC_WEBSOCKET_KEY,
|
||||||
|
HeaderValue::try_from(key.as_str()).unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let msg = Connect::new(head.uri.clone()).set_addr(self.addr);
|
||||||
|
let fut = self.connector.call(msg);
|
||||||
|
|
||||||
|
async move {
|
||||||
|
let io = fut.await?;
|
||||||
|
|
||||||
|
// create Framed and send request
|
||||||
|
let codec = h1::ClientCodec::default();
|
||||||
|
|
||||||
|
// send request and read response
|
||||||
|
let fut = async {
|
||||||
|
io.send(
|
||||||
|
&codec,
|
||||||
|
(RequestHeadType::Rc(head, Some(headers)), BodySize::None).into(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
io.recv(&codec)
|
||||||
|
.await?
|
||||||
|
.ok_or(WsClientError::Disconnected(None))
|
||||||
|
};
|
||||||
|
|
||||||
|
// set request timeout
|
||||||
|
let response = if to.non_zero() {
|
||||||
|
timeout(to, fut)
|
||||||
|
.await
|
||||||
|
.map_err(|_| WsClientError::Timeout)
|
||||||
|
.and_then(|res| res)?
|
||||||
|
} else {
|
||||||
|
fut.await?
|
||||||
|
};
|
||||||
|
|
||||||
|
// verify response
|
||||||
|
if response.status != StatusCode::SWITCHING_PROTOCOLS {
|
||||||
|
return Err(WsClientError::InvalidResponseStatus(response.status));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for "UPGRADE" to websocket header
|
||||||
|
let has_hdr = if let Some(hdr) = response.headers.get(&header::UPGRADE) {
|
||||||
|
if let Ok(s) = hdr.to_str() {
|
||||||
|
s.to_ascii_lowercase().contains("websocket")
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
if !has_hdr {
|
||||||
|
log::trace!("Invalid upgrade header");
|
||||||
|
return Err(WsClientError::InvalidUpgradeHeader);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for "CONNECTION" header
|
||||||
|
if let Some(conn) = response.headers.get(&header::CONNECTION) {
|
||||||
|
if let Ok(s) = conn.to_str() {
|
||||||
|
if !s.to_ascii_lowercase().contains("upgrade") {
|
||||||
|
log::trace!("Invalid connection header: {}", s);
|
||||||
|
return Err(WsClientError::InvalidConnectionHeader(
|
||||||
|
conn.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log::trace!("Invalid connection header: {:?}", conn);
|
||||||
|
return Err(WsClientError::InvalidConnectionHeader(conn.clone()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log::trace!("Missing connection header");
|
||||||
|
return Err(WsClientError::MissingConnectionHeader);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(hdr_key) = response.headers.get(&header::SEC_WEBSOCKET_ACCEPT) {
|
||||||
|
let encoded = ws::hash_key(key.as_ref());
|
||||||
|
if hdr_key.as_bytes() != encoded.as_bytes() {
|
||||||
|
log::trace!(
|
||||||
|
"Invalid challenge response: expected: {} received: {:?}",
|
||||||
|
encoded,
|
||||||
|
key
|
||||||
|
);
|
||||||
|
return Err(WsClientError::InvalidChallengeResponse(
|
||||||
|
encoded,
|
||||||
|
hdr_key.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log::trace!("Missing SEC-WEBSOCKET-ACCEPT header");
|
||||||
|
return Err(WsClientError::MissingWebSocketAcceptHeader);
|
||||||
|
};
|
||||||
|
|
||||||
|
// response and ws io
|
||||||
|
Ok(WsConnection::new(
|
||||||
|
io,
|
||||||
|
ClientResponse::with_empty_payload(response),
|
||||||
|
if server_mode {
|
||||||
|
ws::Codec::new().max_size(max_size)
|
||||||
|
} else {
|
||||||
|
ws::Codec::new().max_size(max_size).client_mode()
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, T> fmt::Debug for WsClient<F, T> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
writeln!(f, "\nWsClient {}:{}", self.head.method, self.head.uri)?;
|
||||||
|
writeln!(f, " headers:")?;
|
||||||
|
for (key, val) in self.head.headers.iter() {
|
||||||
|
writeln!(f, " {:?}: {:?}", key, val)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsClientBuilder<Base, ()> {
|
||||||
|
/// Create new websocket connector
|
||||||
|
fn new<U>(
|
||||||
|
uri: U,
|
||||||
|
) -> WsClientBuilder<
|
||||||
|
Base,
|
||||||
|
impl Service<Request = Connect<Uri>, Response = Io, Error = ConnectError>,
|
||||||
|
>
|
||||||
|
where
|
||||||
|
Uri: TryFrom<U>,
|
||||||
|
<Uri as TryFrom<U>>::Error: Into<HttpError>,
|
||||||
|
{
|
||||||
|
let (head, err) = match Uri::try_from(uri) {
|
||||||
|
Ok(uri) => (
|
||||||
|
RequestHead {
|
||||||
|
uri,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
Err(e) => (Default::default(), Some(e.into())),
|
||||||
|
};
|
||||||
|
|
||||||
|
WsClientBuilder {
|
||||||
|
err,
|
||||||
|
origin: None,
|
||||||
|
protocols: None,
|
||||||
|
inner: Some(Inner {
|
||||||
|
head,
|
||||||
|
connector: Connector::<Uri>::default(),
|
||||||
|
addr: None,
|
||||||
|
max_size: 65_536,
|
||||||
|
server_mode: false,
|
||||||
|
timeout: Millis(5_000),
|
||||||
|
_t: marker::PhantomData,
|
||||||
|
}),
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
cookies: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, T> WsClientBuilder<F, T>
|
||||||
|
where
|
||||||
|
T: Service<Request = Connect<Uri>, Response = Io<F>, Error = ConnectError>,
|
||||||
|
{
|
||||||
|
/// Set socket address of the server.
|
||||||
|
///
|
||||||
|
/// This address is used for connection. If address is not
|
||||||
|
/// provided url's host name get resolved.
|
||||||
|
pub fn address(&mut self, addr: net::SocketAddr) -> &mut Self {
|
||||||
|
if let Some(parts) = parts(&mut self.inner, &self.err) {
|
||||||
|
parts.addr = Some(addr);
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set supported websocket protocols
|
||||||
|
pub fn protocols<U, V>(&mut self, protos: U) -> &mut Self
|
||||||
|
where
|
||||||
|
U: IntoIterator<Item = V>,
|
||||||
|
V: AsRef<str>,
|
||||||
|
{
|
||||||
|
let mut protos = protos
|
||||||
|
.into_iter()
|
||||||
|
.fold(String::new(), |acc, s| acc + s.as_ref() + ",");
|
||||||
|
protos.pop();
|
||||||
|
self.protocols = Some(protos);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
/// Set a cookie
|
||||||
|
pub fn cookie(&mut self, cookie: Cookie<'_>) -> &mut Self {
|
||||||
|
if self.cookies.is_none() {
|
||||||
|
let mut jar = CookieJar::new();
|
||||||
|
jar.add(cookie.into_owned());
|
||||||
|
self.cookies = Some(jar)
|
||||||
|
} else {
|
||||||
|
self.cookies.as_mut().unwrap().add(cookie.into_owned());
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set request Origin
|
||||||
|
pub fn origin<V, E>(&mut self, origin: V) -> &mut Self
|
||||||
|
where
|
||||||
|
HeaderValue: TryFrom<V, Error = E>,
|
||||||
|
HttpError: From<E>,
|
||||||
|
{
|
||||||
|
match HeaderValue::try_from(origin) {
|
||||||
|
Ok(value) => self.origin = Some(value),
|
||||||
|
Err(e) => self.err = Some(e.into()),
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set max frame size
|
||||||
|
///
|
||||||
|
/// By default max size is set to 64kb
|
||||||
|
pub fn max_frame_size(&mut self, size: usize) -> &mut Self {
|
||||||
|
if let Some(parts) = parts(&mut self.inner, &self.err) {
|
||||||
|
parts.max_size = size;
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Disable payload masking. By default ws client masks frame payload.
|
||||||
|
pub fn server_mode(&mut self) -> &mut Self {
|
||||||
|
if let Some(parts) = parts(&mut self.inner, &self.err) {
|
||||||
|
parts.server_mode = true;
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append a header.
|
||||||
|
///
|
||||||
|
/// Header gets appended to existing header.
|
||||||
|
/// To override header use `set_header()` method.
|
||||||
|
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
|
||||||
|
where
|
||||||
|
HeaderName: TryFrom<K>,
|
||||||
|
HeaderValue: TryFrom<V>,
|
||||||
|
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
|
||||||
|
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
|
||||||
|
{
|
||||||
|
if let Some(parts) = parts(&mut self.inner, &self.err) {
|
||||||
|
match HeaderName::try_from(key) {
|
||||||
|
Ok(key) => match HeaderValue::try_from(value) {
|
||||||
|
Ok(value) => {
|
||||||
|
parts.head.headers.append(key, value);
|
||||||
|
}
|
||||||
|
Err(e) => self.err = Some(e.into()),
|
||||||
|
},
|
||||||
|
Err(e) => self.err = Some(e.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert a header, replaces existing header.
|
||||||
|
pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self
|
||||||
|
where
|
||||||
|
HeaderName: TryFrom<K>,
|
||||||
|
HeaderValue: TryFrom<V>,
|
||||||
|
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
|
||||||
|
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
|
||||||
|
{
|
||||||
|
if let Some(parts) = parts(&mut self.inner, &self.err) {
|
||||||
|
match HeaderName::try_from(key) {
|
||||||
|
Ok(key) => match HeaderValue::try_from(value) {
|
||||||
|
Ok(value) => {
|
||||||
|
parts.head.headers.insert(key, value);
|
||||||
|
}
|
||||||
|
Err(e) => self.err = Some(e.into()),
|
||||||
|
},
|
||||||
|
Err(e) => self.err = Some(e.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert a header only if it is not yet set.
|
||||||
|
pub fn set_header_if_none<K, V>(&mut self, key: K, value: V) -> &mut Self
|
||||||
|
where
|
||||||
|
HeaderName: TryFrom<K>,
|
||||||
|
HeaderValue: TryFrom<V>,
|
||||||
|
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
|
||||||
|
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
|
||||||
|
{
|
||||||
|
if let Some(parts) = parts(&mut self.inner, &self.err) {
|
||||||
|
match HeaderName::try_from(key) {
|
||||||
|
Ok(key) => {
|
||||||
|
if !parts.head.headers.contains_key(&key) {
|
||||||
|
match HeaderValue::try_from(value) {
|
||||||
|
Ok(value) => {
|
||||||
|
parts.head.headers.insert(key, value);
|
||||||
|
}
|
||||||
|
Err(e) => self.err = Some(e.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => self.err = Some(e.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set HTTP basic authorization header
|
||||||
|
pub fn basic_auth<U>(&mut self, username: U, password: Option<&str>) -> &mut Self
|
||||||
|
where
|
||||||
|
U: fmt::Display,
|
||||||
|
{
|
||||||
|
let auth = match password {
|
||||||
|
Some(password) => format!("{}:{}", username, password),
|
||||||
|
None => format!("{}:", username),
|
||||||
|
};
|
||||||
|
self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set HTTP bearer authentication header
|
||||||
|
pub fn bearer_auth<U>(&mut self, token: U) -> &mut Self
|
||||||
|
where
|
||||||
|
U: fmt::Display,
|
||||||
|
{
|
||||||
|
self.header(AUTHORIZATION, format!("Bearer {}", token))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set request timeout.
|
||||||
|
///
|
||||||
|
/// Request timeout is the total time before a response must be received.
|
||||||
|
/// Default value is 5 seconds.
|
||||||
|
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
|
||||||
|
if let Some(parts) = parts(&mut self.inner, &self.err) {
|
||||||
|
parts.timeout = timeout.into();
|
||||||
|
}
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use custom connector
|
||||||
|
pub fn connector<F1, T1>(&mut self, connector: T1) -> WsClientBuilder<F1, T1>
|
||||||
|
where
|
||||||
|
F1: Filter,
|
||||||
|
T1: Service<Request = Connect<Uri>, Response = Io<F1>, Error = ConnectError>,
|
||||||
|
{
|
||||||
|
let inner = self.inner.take().expect("cannot reuse WsClient builder");
|
||||||
|
|
||||||
|
WsClientBuilder {
|
||||||
|
inner: Some(Inner {
|
||||||
|
connector,
|
||||||
|
head: inner.head,
|
||||||
|
addr: inner.addr,
|
||||||
|
max_size: inner.max_size,
|
||||||
|
server_mode: inner.server_mode,
|
||||||
|
timeout: inner.timeout,
|
||||||
|
_t: marker::PhantomData,
|
||||||
|
}),
|
||||||
|
err: self.err.take(),
|
||||||
|
protocols: self.protocols.take(),
|
||||||
|
origin: self.origin.take(),
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
cookies: self.cookies.take(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "openssl")]
|
||||||
|
/// Use openssl connector.
|
||||||
|
pub fn openssl(
|
||||||
|
&mut self,
|
||||||
|
connector: openssl::SslConnector,
|
||||||
|
) -> WsClientBuilder<openssl::SslFilter, openssl::Connector<Uri>> {
|
||||||
|
self.connector(openssl::Connector::new(connector))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "rustls")]
|
||||||
|
/// Use rustls connector.
|
||||||
|
pub fn rustls(
|
||||||
|
&mut self,
|
||||||
|
config: std::sync::Arc<rustls::ClientConfig>,
|
||||||
|
) -> WsClientBuilder<rustls::TlsFilter, rustls::Connector<Uri>> {
|
||||||
|
self.connector(rustls::Connector::from(config))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This method construct new `WsClientBuilder`
|
||||||
|
pub fn take(&mut self) -> WsClientBuilder<F, T> {
|
||||||
|
WsClientBuilder {
|
||||||
|
inner: self.inner.take(),
|
||||||
|
err: self.err.take(),
|
||||||
|
origin: self.origin.take(),
|
||||||
|
protocols: self.protocols.take(),
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
cookies: self.cookies.take(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Complete building process and construct websockets client.
|
||||||
|
pub fn finish(&mut self) -> Result<WsClient<F, T>, WsClientBuilderError> {
|
||||||
|
if let Some(e) = self.err.take() {
|
||||||
|
return Err(WsClientBuilderError::Http(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
// #[allow(unused_mut)]
|
||||||
|
let mut inner = self.inner.take().expect("cannot reuse WsClient builder");
|
||||||
|
|
||||||
|
// validate uri
|
||||||
|
let uri = &inner.head.uri;
|
||||||
|
if uri.host().is_none() {
|
||||||
|
return Err(WsClientBuilderError::MissingHost);
|
||||||
|
} else if uri.scheme().is_none() {
|
||||||
|
return Err(WsClientBuilderError::MissingScheme);
|
||||||
|
} else if let Some(scheme) = uri.scheme() {
|
||||||
|
match scheme.as_str() {
|
||||||
|
"http" | "ws" | "https" | "wss" => (),
|
||||||
|
_ => return Err(WsClientBuilderError::UnknownScheme),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(WsClientBuilderError::UnknownScheme);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inner.head.headers.contains_key(header::HOST) {
|
||||||
|
inner.head.headers.insert(
|
||||||
|
header::HOST,
|
||||||
|
HeaderValue::from_str(uri.host().unwrap()).unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
{
|
||||||
|
use percent_encoding::percent_encode;
|
||||||
|
use std::fmt::Write as FmtWrite;
|
||||||
|
|
||||||
|
// set cookies
|
||||||
|
if let Some(ref mut jar) = self.cookies {
|
||||||
|
let mut cookie = String::new();
|
||||||
|
for c in jar.delta() {
|
||||||
|
let name = percent_encode(
|
||||||
|
c.name().as_bytes(),
|
||||||
|
crate::http::helpers::USERINFO,
|
||||||
|
);
|
||||||
|
let value = percent_encode(
|
||||||
|
c.value().as_bytes(),
|
||||||
|
crate::http::helpers::USERINFO,
|
||||||
|
);
|
||||||
|
let _ = write!(&mut cookie, "; {}={}", name, value);
|
||||||
|
}
|
||||||
|
inner.head.headers.insert(
|
||||||
|
header::COOKIE,
|
||||||
|
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// origin
|
||||||
|
if let Some(origin) = self.origin.take() {
|
||||||
|
inner.head.headers.insert(header::ORIGIN, origin);
|
||||||
|
}
|
||||||
|
|
||||||
|
inner.head.set_connection_type(ConnectionType::Upgrade);
|
||||||
|
inner
|
||||||
|
.head
|
||||||
|
.headers
|
||||||
|
.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
|
||||||
|
inner.head.headers.insert(
|
||||||
|
header::SEC_WEBSOCKET_VERSION,
|
||||||
|
HeaderValue::from_static("13"),
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(protocols) = self.protocols.take() {
|
||||||
|
inner.head.headers.insert(
|
||||||
|
header::SEC_WEBSOCKET_PROTOCOL,
|
||||||
|
HeaderValue::try_from(protocols.as_str()).unwrap(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(WsClient {
|
||||||
|
connector: inner.connector,
|
||||||
|
head: Rc::new(inner.head),
|
||||||
|
addr: inner.addr,
|
||||||
|
max_size: inner.max_size,
|
||||||
|
server_mode: inner.server_mode,
|
||||||
|
timeout: inner.timeout,
|
||||||
|
extra_headers: RefCell::new(None),
|
||||||
|
_t: marker::PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn parts<'a, F, T>(
|
||||||
|
parts: &'a mut Option<Inner<F, T>>,
|
||||||
|
err: &Option<HttpError>,
|
||||||
|
) -> Option<&'a mut Inner<F, T>> {
|
||||||
|
if err.is_some() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
parts.as_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F, T> fmt::Debug for WsClientBuilder<F, T> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
if let Some(ref parts) = self.inner {
|
||||||
|
writeln!(
|
||||||
|
f,
|
||||||
|
"\nWsClientBuilder {}:{}",
|
||||||
|
parts.head.method, parts.head.uri
|
||||||
|
)?;
|
||||||
|
writeln!(f, " headers:")?;
|
||||||
|
for (key, val) in parts.head.headers.iter() {
|
||||||
|
writeln!(f, " {:?}: {:?}", key, val)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
writeln!(f, "WsClientBuilder(Consumed)")?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct WsConnection<F> {
|
||||||
|
io: Io<F>,
|
||||||
|
codec: ws::Codec,
|
||||||
|
res: ClientResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> WsConnection<F> {
|
||||||
|
fn new(io: Io<F>, res: ClientResponse, codec: ws::Codec) -> Self {
|
||||||
|
Self { io, codec, res }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get reference to response
|
||||||
|
pub fn response(&self) -> &ClientResponse {
|
||||||
|
&self.res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> WsConnection<F> {
|
||||||
|
/// Get ws sink
|
||||||
|
pub fn sink(&self) -> ws::WsSink {
|
||||||
|
ws::WsSink::new(self.io.get_ref(), self.codec.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the `WsConnection`, returning it'as underlying I/O stream object
|
||||||
|
/// and response.
|
||||||
|
pub fn into_inner(self) -> (Io<F>, ws::Codec, ClientResponse) {
|
||||||
|
(self.io, self.codec, self.res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsConnection<Sealed> {
|
||||||
|
// TODO: fix close frame handling
|
||||||
|
/// Start client websockets with `SinkService` and `mpsc::Receiver<Frame>`
|
||||||
|
pub fn start_default(self) -> mpsc::Receiver<Result<ws::Frame, WsError<()>>> {
|
||||||
|
let (tx, rx): (_, mpsc::Receiver<Result<ws::Frame, WsError<()>>>) =
|
||||||
|
mpsc::channel();
|
||||||
|
|
||||||
|
rt::spawn(async move {
|
||||||
|
let io = self.io.get_ref();
|
||||||
|
let srv = sink::SinkService::new(tx.clone()).map(|_| None);
|
||||||
|
|
||||||
|
if let Err(err) = self
|
||||||
|
.start(into_service(move |item| {
|
||||||
|
let io = io.clone();
|
||||||
|
let close = matches!(item, ws::Frame::Close(_));
|
||||||
|
let fut = srv.call(Ok::<_, WsError<()>>(item));
|
||||||
|
async move {
|
||||||
|
let result = fut.await.map_err(|_| ());
|
||||||
|
if close {
|
||||||
|
io.close();
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
let _ = tx.send(Err(err));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
rx
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start client websockets service.
|
||||||
|
pub async fn start<T, U>(self, service: U) -> Result<(), WsError<T::Error>>
|
||||||
|
where
|
||||||
|
T: Service<Request = ws::Frame, Response = Option<ws::Message>> + 'static,
|
||||||
|
U: IntoService<T>,
|
||||||
|
{
|
||||||
|
let service = apply_fn(
|
||||||
|
service.into_service().map_err(WsError::Service),
|
||||||
|
|req, srv| match req {
|
||||||
|
DispatchItem::Item(item) => Either::Left(srv.call(item)),
|
||||||
|
DispatchItem::WBackPressureEnabled
|
||||||
|
| DispatchItem::WBackPressureDisabled => Either::Right(Ready::Ok(None)),
|
||||||
|
DispatchItem::KeepAliveTimeout => {
|
||||||
|
Either::Right(Ready::Err(WsError::KeepAlive))
|
||||||
|
}
|
||||||
|
DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => {
|
||||||
|
Either::Right(Ready::Err(WsError::Protocol(e)))
|
||||||
|
}
|
||||||
|
DispatchItem::Disconnect(e) => {
|
||||||
|
Either::Right(Ready::Err(WsError::Disconnected(e)))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
Dispatcher::new(self.io, self.codec, service, Default::default()).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Filter> WsConnection<F> {
|
||||||
|
/// Convert I/O stream to boxed stream;
|
||||||
|
pub fn seal(self) -> WsConnection<Sealed> {
|
||||||
|
WsConnection {
|
||||||
|
io: self.io.seal(),
|
||||||
|
codec: self.codec,
|
||||||
|
res: self.res,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[crate::rt_test]
|
||||||
|
async fn test_debug() {
|
||||||
|
let mut builder = WsClient::build("http://localhost")
|
||||||
|
.header("x-test", "111")
|
||||||
|
.take();
|
||||||
|
let repr = format!("{:?}", builder);
|
||||||
|
assert!(repr.contains("WsClientBuilder"));
|
||||||
|
assert!(repr.contains("x-test"));
|
||||||
|
|
||||||
|
let client = builder.finish().unwrap();
|
||||||
|
let repr = format!("{:?}", client);
|
||||||
|
assert!(repr.contains("WsClient"));
|
||||||
|
assert!(repr.contains("x-test"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[crate::rt_test]
|
||||||
|
async fn header_override() {
|
||||||
|
let req = WsClient::build("http://localhost")
|
||||||
|
.header(header::CONTENT_TYPE, "111")
|
||||||
|
.set_header(header::CONTENT_TYPE, "222")
|
||||||
|
.finish()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
req.head
|
||||||
|
.headers
|
||||||
|
.get(header::CONTENT_TYPE)
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap(),
|
||||||
|
"222"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn basic_errs() {
|
||||||
|
let err = WsClient::build("localhost").finish().err().unwrap();
|
||||||
|
assert!(matches!(err, WsClientBuilderError::MissingScheme));
|
||||||
|
let err = WsClient::build("unknown://localhost")
|
||||||
|
.finish()
|
||||||
|
.err()
|
||||||
|
.unwrap();
|
||||||
|
assert!(matches!(err, WsClientBuilderError::UnknownScheme));
|
||||||
|
let err = WsClient::build("/").finish().err().unwrap();
|
||||||
|
assert!(matches!(err, WsClientBuilderError::MissingHost));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[crate::rt_test]
|
||||||
|
async fn basic_auth() {
|
||||||
|
let client = WsClient::build("http://localhost")
|
||||||
|
.basic_auth("username", Some("password"))
|
||||||
|
.finish()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
client
|
||||||
|
.head
|
||||||
|
.headers
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap(),
|
||||||
|
"Basic dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||||
|
);
|
||||||
|
|
||||||
|
let client = WsClient::build("http://localhost")
|
||||||
|
.basic_auth("username", None)
|
||||||
|
.finish()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
client
|
||||||
|
.head
|
||||||
|
.headers
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap(),
|
||||||
|
"Basic dXNlcm5hbWU6"
|
||||||
|
);
|
||||||
|
|
||||||
|
client.set_basic_auth("username", Some("password")).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
client
|
||||||
|
.extra_headers
|
||||||
|
.borrow()
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap(),
|
||||||
|
"Basic dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[crate::rt_test]
|
||||||
|
async fn bearer_auth() {
|
||||||
|
let client = WsClient::build("http://localhost")
|
||||||
|
.bearer_auth("someS3cr3tAutht0k3n")
|
||||||
|
.finish()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
client
|
||||||
|
.head
|
||||||
|
.headers
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap(),
|
||||||
|
"Bearer someS3cr3tAutht0k3n"
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = client.set_bearer_auth("someS3cr3tAutht0k2n");
|
||||||
|
assert_eq!(
|
||||||
|
client
|
||||||
|
.extra_headers
|
||||||
|
.borrow()
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap(),
|
||||||
|
"Bearer someS3cr3tAutht0k2n"
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = client.connect();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cookie")]
|
||||||
|
#[crate::rt_test]
|
||||||
|
async fn basics() {
|
||||||
|
let mut builder = WsClient::build("http://localhost/")
|
||||||
|
.origin("test-origin")
|
||||||
|
.max_frame_size(100)
|
||||||
|
.server_mode()
|
||||||
|
.protocols(&["v1", "v2"])
|
||||||
|
.set_header_if_none(header::CONTENT_TYPE, "json")
|
||||||
|
.set_header_if_none(header::CONTENT_TYPE, "text")
|
||||||
|
.cookie(Cookie::build("cookie1", "value1").finish())
|
||||||
|
.take();
|
||||||
|
assert_eq!(
|
||||||
|
builder.origin.as_ref().unwrap().to_str().unwrap(),
|
||||||
|
"test-origin"
|
||||||
|
);
|
||||||
|
assert_eq!(builder.inner.as_ref().unwrap().max_size, 100);
|
||||||
|
assert_eq!(builder.inner.as_ref().unwrap().server_mode, true);
|
||||||
|
assert_eq!(builder.protocols, Some("v1,v2".to_string()));
|
||||||
|
|
||||||
|
let client = builder.finish().unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
client.head.headers.get(header::CONTENT_TYPE).unwrap(),
|
||||||
|
header::HeaderValue::from_static("json")
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = client.connect().await;
|
||||||
|
|
||||||
|
assert!(WsClient::build("/").finish().is_err());
|
||||||
|
assert!(WsClient::build("http:///test").finish().is_err());
|
||||||
|
assert!(WsClient::build("hmm://test.com/").finish().is_err());
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,9 +3,9 @@ use std::cell::Cell;
|
||||||
use crate::codec::{Decoder, Encoder};
|
use crate::codec::{Decoder, Encoder};
|
||||||
use crate::util::{ByteString, Bytes, BytesMut};
|
use crate::util::{ByteString, Bytes, BytesMut};
|
||||||
|
|
||||||
|
use super::error::ProtocolError;
|
||||||
use super::frame::Parser;
|
use super::frame::Parser;
|
||||||
use super::proto::{CloseReason, OpCode};
|
use super::proto::{CloseReason, OpCode};
|
||||||
use super::ProtocolError;
|
|
||||||
|
|
||||||
/// WebSocket message
|
/// WebSocket message
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
|
|
128
ntex/src/ws/error.rs
Normal file
128
ntex/src/ws/error.rs
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
//! WebSocket protocol related errors.
|
||||||
|
use std::{error, io};
|
||||||
|
|
||||||
|
use derive_more::{Display, From};
|
||||||
|
|
||||||
|
use crate::connect::ConnectError;
|
||||||
|
use crate::http::error::{HttpError, ParseError};
|
||||||
|
use crate::http::{header::HeaderValue, StatusCode};
|
||||||
|
use crate::util::Either;
|
||||||
|
|
||||||
|
use super::OpCode;
|
||||||
|
|
||||||
|
/// Websocket service errors
|
||||||
|
#[derive(Debug, Display)]
|
||||||
|
pub enum WsError<E> {
|
||||||
|
Service(E),
|
||||||
|
/// Keep-alive error
|
||||||
|
KeepAlive,
|
||||||
|
/// Ws protocol level error
|
||||||
|
Protocol(ProtocolError),
|
||||||
|
/// Peer has been disconnected
|
||||||
|
#[display(fmt = "Peer has been disconnected: {:?}", _0)]
|
||||||
|
Disconnected(Option<io::Error>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Websocket protocol errors
|
||||||
|
#[derive(Debug, Display, From)]
|
||||||
|
pub enum ProtocolError {
|
||||||
|
/// Received an unmasked frame from client
|
||||||
|
#[display(fmt = "Received an unmasked frame from client")]
|
||||||
|
UnmaskedFrame,
|
||||||
|
/// Received a masked frame from server
|
||||||
|
#[display(fmt = "Received a masked frame from server")]
|
||||||
|
MaskedFrame,
|
||||||
|
/// Encountered invalid opcode
|
||||||
|
#[display(fmt = "Invalid opcode: {}", _0)]
|
||||||
|
InvalidOpcode(u8),
|
||||||
|
/// Invalid control frame length
|
||||||
|
#[display(fmt = "Invalid control frame length: {}", _0)]
|
||||||
|
InvalidLength(usize),
|
||||||
|
/// Bad web socket op code
|
||||||
|
#[display(fmt = "Bad web socket op code")]
|
||||||
|
BadOpCode,
|
||||||
|
/// A payload reached size limit.
|
||||||
|
#[display(fmt = "A payload reached size limit.")]
|
||||||
|
Overflow,
|
||||||
|
/// Continuation is not started
|
||||||
|
#[display(fmt = "Continuation is not started.")]
|
||||||
|
ContinuationNotStarted,
|
||||||
|
/// Received new continuation but it is already started
|
||||||
|
#[display(fmt = "Received new continuation but it is already started")]
|
||||||
|
ContinuationStarted,
|
||||||
|
/// Unknown continuation fragment
|
||||||
|
#[display(fmt = "Unknown continuation fragment.")]
|
||||||
|
ContinuationFragment(OpCode),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for ProtocolError {}
|
||||||
|
|
||||||
|
/// Websocket client error
|
||||||
|
#[derive(Debug, Display, From)]
|
||||||
|
pub enum WsClientBuilderError {
|
||||||
|
#[display(fmt = "Missing url scheme")]
|
||||||
|
MissingScheme,
|
||||||
|
#[display(fmt = "Unknown url scheme")]
|
||||||
|
UnknownScheme,
|
||||||
|
#[display(fmt = "Missing host name")]
|
||||||
|
MissingHost,
|
||||||
|
#[display(fmt = "Url parse error: {}", _0)]
|
||||||
|
Http(HttpError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for WsClientBuilderError {}
|
||||||
|
|
||||||
|
/// Websocket client error
|
||||||
|
#[derive(Debug, Display, From)]
|
||||||
|
pub enum WsClientError {
|
||||||
|
/// Invalid response
|
||||||
|
#[display(fmt = "Invalid response")]
|
||||||
|
InvalidResponse(ParseError),
|
||||||
|
/// Invalid response status
|
||||||
|
#[display(fmt = "Invalid response status")]
|
||||||
|
InvalidResponseStatus(StatusCode),
|
||||||
|
/// Invalid upgrade header
|
||||||
|
#[display(fmt = "Invalid upgrade header")]
|
||||||
|
InvalidUpgradeHeader,
|
||||||
|
/// Invalid connection header
|
||||||
|
#[display(fmt = "Invalid connection header")]
|
||||||
|
InvalidConnectionHeader(HeaderValue),
|
||||||
|
/// Missing CONNECTION header
|
||||||
|
#[display(fmt = "Missing CONNECTION header")]
|
||||||
|
MissingConnectionHeader,
|
||||||
|
/// Missing SEC-WEBSOCKET-ACCEPT header
|
||||||
|
#[display(fmt = "Missing SEC-WEBSOCKET-ACCEPT header")]
|
||||||
|
MissingWebSocketAcceptHeader,
|
||||||
|
/// Invalid challenge response
|
||||||
|
#[display(fmt = "Invalid challenge response")]
|
||||||
|
InvalidChallengeResponse(String, HeaderValue),
|
||||||
|
/// Protocol error
|
||||||
|
#[display(fmt = "{}", _0)]
|
||||||
|
Protocol(ProtocolError),
|
||||||
|
/// Response took too long
|
||||||
|
#[display(fmt = "Timeout out while waiting for response")]
|
||||||
|
Timeout,
|
||||||
|
/// Failed to connect to host
|
||||||
|
#[display(fmt = "Failed to connect to host: {}", _0)]
|
||||||
|
Connect(ConnectError),
|
||||||
|
/// Connector has been disconnected
|
||||||
|
#[display(fmt = "Connector has been disconnected: {:?}", _0)]
|
||||||
|
Disconnected(Option<io::Error>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl error::Error for WsClientError {}
|
||||||
|
|
||||||
|
impl From<Either<ParseError, io::Error>> for WsClientError {
|
||||||
|
fn from(err: Either<ParseError, io::Error>) -> Self {
|
||||||
|
match err {
|
||||||
|
Either::Left(err) => WsClientError::InvalidResponse(err),
|
||||||
|
Either::Right(err) => WsClientError::Disconnected(Some(err)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Either<io::Error, io::Error>> for WsClientError {
|
||||||
|
fn from(err: Either<io::Error, io::Error>) -> Self {
|
||||||
|
WsClientError::Disconnected(Some(err.into_inner()))
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,7 +4,7 @@ use log::debug;
|
||||||
use nanorand::{Rng, WyRand};
|
use nanorand::{Rng, WyRand};
|
||||||
|
|
||||||
use super::proto::{CloseCode, CloseReason, OpCode};
|
use super::proto::{CloseCode, CloseReason, OpCode};
|
||||||
use super::{mask::apply_mask, ProtocolError};
|
use super::{error::ProtocolError, mask::apply_mask};
|
||||||
use crate::util::{Buf, BufMut, BytesMut};
|
use crate::util::{Buf, BufMut, BytesMut};
|
||||||
|
|
||||||
/// WebSocket frame parser.
|
/// WebSocket frame parser.
|
||||||
|
|
|
@ -3,10 +3,7 @@
|
||||||
//! To setup a `WebSocket`, first do web socket handshake then on success
|
//! To setup a `WebSocket`, first do web socket handshake then on success
|
||||||
//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to
|
//! convert `Payload` into a `WsStream` stream and then use `WsWriter` to
|
||||||
//! communicate with the peer.
|
//! communicate with the peer.
|
||||||
use std::io;
|
mod client;
|
||||||
|
|
||||||
use derive_more::{Display, From};
|
|
||||||
|
|
||||||
mod codec;
|
mod codec;
|
||||||
mod frame;
|
mod frame;
|
||||||
mod mask;
|
mod mask;
|
||||||
|
@ -14,55 +11,11 @@ mod proto;
|
||||||
mod sink;
|
mod sink;
|
||||||
mod stream;
|
mod stream;
|
||||||
|
|
||||||
|
pub mod error;
|
||||||
|
|
||||||
|
pub use self::client::{WsClient, WsClientBuilder, WsConnection};
|
||||||
pub use self::codec::{Codec, Frame, Item, Message};
|
pub use self::codec::{Codec, Frame, Item, Message};
|
||||||
pub use self::frame::Parser;
|
pub use self::frame::Parser;
|
||||||
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
|
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
|
||||||
pub use self::sink::WsSink;
|
pub use self::sink::WsSink;
|
||||||
pub use self::stream::{StreamDecoder, StreamEncoder};
|
pub use self::stream::{StreamDecoder, StreamEncoder};
|
||||||
|
|
||||||
/// Websocket service errors
|
|
||||||
#[derive(Debug, Display)]
|
|
||||||
pub enum WsError<E> {
|
|
||||||
Service(E),
|
|
||||||
KeepAlive,
|
|
||||||
Disconnected,
|
|
||||||
Protocol(ProtocolError),
|
|
||||||
Io(io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Websocket protocol errors
|
|
||||||
#[derive(Debug, Display, From)]
|
|
||||||
pub enum ProtocolError {
|
|
||||||
/// Received an unmasked frame from client
|
|
||||||
#[display(fmt = "Received an unmasked frame from client")]
|
|
||||||
UnmaskedFrame,
|
|
||||||
/// Received a masked frame from server
|
|
||||||
#[display(fmt = "Received a masked frame from server")]
|
|
||||||
MaskedFrame,
|
|
||||||
/// Encountered invalid opcode
|
|
||||||
#[display(fmt = "Invalid opcode: {}", _0)]
|
|
||||||
InvalidOpcode(u8),
|
|
||||||
/// Invalid control frame length
|
|
||||||
#[display(fmt = "Invalid control frame length: {}", _0)]
|
|
||||||
InvalidLength(usize),
|
|
||||||
/// Bad web socket op code
|
|
||||||
#[display(fmt = "Bad web socket op code")]
|
|
||||||
BadOpCode,
|
|
||||||
/// A payload reached size limit.
|
|
||||||
#[display(fmt = "A payload reached size limit.")]
|
|
||||||
Overflow,
|
|
||||||
/// Continuation is not started
|
|
||||||
#[display(fmt = "Continuation is not started.")]
|
|
||||||
ContinuationNotStarted,
|
|
||||||
/// Received new continuation but it is already started
|
|
||||||
#[display(fmt = "Received new continuation but it is already started")]
|
|
||||||
ContinuationStarted,
|
|
||||||
/// Unknown continuation fragment
|
|
||||||
#[display(fmt = "Unknown continuation fragment.")]
|
|
||||||
ContinuationFragment(OpCode),
|
|
||||||
/// IO Error
|
|
||||||
#[display(fmt = "IO Error: {:?}", _0)]
|
|
||||||
Io(io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::error::Error for ProtocolError {}
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ impl WsSink {
|
||||||
pub fn send(
|
pub fn send(
|
||||||
&self,
|
&self,
|
||||||
item: ws::Message,
|
item: ws::Message,
|
||||||
) -> impl Future<Output = Result<(), ws::ProtocolError>> {
|
) -> impl Future<Output = Result<(), ws::error::ProtocolError>> {
|
||||||
let inner = self.0.clone();
|
let inner = self.0.clone();
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::{
|
||||||
cell::RefCell, fmt, marker::PhantomData, pin::Pin, rc::Rc, task::Context, task::Poll,
|
cell::RefCell, fmt, marker::PhantomData, pin::Pin, rc::Rc, task::Context, task::Poll,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{Codec, Frame, Message, ProtocolError};
|
use super::{error::ProtocolError, Codec, Frame, Message};
|
||||||
use crate::util::{Bytes, BytesMut};
|
use crate::util::{Bytes, BytesMut};
|
||||||
use crate::{codec::Decoder, codec::Encoder, Sink, Stream};
|
use crate::{codec::Decoder, codec::Encoder, Sink, Stream};
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ impl Service for WsService {
|
||||||
io.encode((res, body::BodySize::None).into(), &codec)
|
io.encode((res, body::BodySize::None).into(), &codec)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Dispatcher::new(io.into_boxed(), ws::Codec::new(), service, Timer::default())
|
Dispatcher::new(io.seal(), ws::Codec::new(), service, Timer::default())
|
||||||
.await
|
.await
|
||||||
.map_err(|_| panic!())
|
.map_err(|_| panic!())
|
||||||
};
|
};
|
||||||
|
@ -96,7 +96,7 @@ async fn test_simple() {
|
||||||
let conn = srv.ws().await.unwrap();
|
let conn = srv.ws().await.unwrap();
|
||||||
assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS);
|
assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS);
|
||||||
|
|
||||||
let (_, io, codec) = conn.into_inner();
|
let (io, codec, _) = conn.into_inner();
|
||||||
io.send(&codec, ws::Message::Text(ByteString::from_static("text")))
|
io.send(&codec, ws::Message::Text(ByteString::from_static("text")))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
@ -41,7 +41,7 @@ async fn test_simple() {
|
||||||
|
|
||||||
// start websocket service
|
// start websocket service
|
||||||
Dispatcher::new(
|
Dispatcher::new(
|
||||||
io.into_boxed(),
|
io.seal(),
|
||||||
ws::Codec::default(),
|
ws::Codec::default(),
|
||||||
ws_service,
|
ws_service,
|
||||||
Default::default(),
|
Default::default(),
|
||||||
|
@ -53,7 +53,7 @@ async fn test_simple() {
|
||||||
});
|
});
|
||||||
|
|
||||||
// client service
|
// client service
|
||||||
let (_, io, codec) = srv.ws().await.unwrap().into_inner();
|
let (io, codec, _) = srv.ws().await.unwrap().into_inner();
|
||||||
io.send(&codec, ws::Message::Text(ByteString::from_static("text")))
|
io.send(&codec, ws::Message::Text(ByteString::from_static("text")))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
|
@ -37,7 +37,7 @@ async fn web_ws() {
|
||||||
});
|
});
|
||||||
|
|
||||||
// client service
|
// client service
|
||||||
let (_, io, codec) = srv.ws().await.unwrap().into_inner();
|
let (io, codec, _) = srv.ws().await.unwrap().into_inner();
|
||||||
io.send(&codec, ws::Message::Text(ByteString::from_static("text")))
|
io.send(&codec, ws::Message::Text(ByteString::from_static("text")))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue