Proper handling responses from ws web handler (#189)

* Proper handling responses from ws web handler
This commit is contained in:
Nikolay Kim 2023-03-15 17:19:49 +09:00 committed by GitHub
parent 3db65156ee
commit 9aeb50c847
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 229 additions and 105 deletions

View file

@ -970,6 +970,7 @@ impl BufMut for &mut [u8] {
fn _assert_trait_object(_b: &dyn BufMut) {}
#[cfg(test)]
#[allow(unused_allocation, warnings)]
mod tests {
use super::*;
use crate::{BytesMut, BytesVec};

View file

@ -1,5 +1,9 @@
# Changes
## [0.6.5] - 2023-03-15
* web: Proper handling responses from ws web handler
## [0.6.4] - 2023-03-11
* http: Add `ClientResponse::headers_mut()` method

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.6.4"
version = "0.6.5"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -76,7 +76,7 @@ pin-project-lite = "0.2"
regex = { version = "1.7.0", default-features = false, features = ["std"] }
sha-1 = "0.10"
serde = { version = "1.0", features=["derive"] }
socket2 = "0.4"
socket2 = "0.5"
thiserror = "1.0"
# http/web framework

View file

@ -217,6 +217,13 @@ impl ClientRequest {
self
}
#[inline]
/// Set connection type of the message
pub fn set_connection_type(mut self, ctype: ConnectionType) -> Self {
self.head.set_connection_type(ctype);
self
}
/// Force close connection instead of returning it back to connections pool.
/// This setting affect only http/1 connections.
#[inline]

View file

@ -55,7 +55,13 @@ impl Clone for Codec {
impl fmt::Debug for Codec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "h1::Codec({:?})", self.flags)
f.debug_struct("h1::Codec")
.field("version", &self.version)
.field("flags", &self.flags)
.field("ctype", &self.ctype)
.field("encoder", &self.encoder)
.field("decoder", &self.decoder)
.finish()
}
}
@ -113,6 +119,12 @@ impl Codec {
flags.insert(f);
self.flags.set(flags);
}
pub(super) fn unset_streaming(&self) {
let mut flags = self.flags.get();
flags.remove(Flags::STREAM);
self.flags.set(flags);
}
}
impl Decoder for Codec {

View file

@ -9,7 +9,7 @@ use crate::http;
use crate::http::body::{BodySize, MessageBody, ResponseBody};
use crate::http::config::{DispatcherConfig, OnRequest};
use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError};
use crate::http::message::CurrentIo;
use crate::http::message::{ConnectionType, CurrentIo};
use crate::http::request::Request;
use crate::http::response::Response;
@ -58,6 +58,11 @@ enum State<B> {
ReadPayload,
#[error("State::SendPayload")]
SendPayload { body: ResponseBody<B> },
#[error("State::SendPayloadAndStop")]
SendPayloadAndStop {
body: ResponseBody<B>,
boxed_io: Option<Box<(IoBoxed, Codec)>>,
},
#[error("State::Upgrade")]
Upgrade(Option<Request>),
#[error("State::StopIo")]
@ -187,23 +192,33 @@ where
match ready!(fut.poll(cx)) {
Ok(res) => {
let (msg, body) = res.into().into_parts();
let item = if let Some(item) = msg.head().take_io() {
let io = if let Some(item) = msg.head().take_io() {
item
} else {
log::trace!("Handler service consumed io, exit");
log::trace!("Handler service consumed io, stop");
return Poll::Ready(Ok(()));
};
let _ = item
io.1.set_ctype(ConnectionType::Close);
io.1.unset_streaming();
let result = io
.0
.encode(Message::Item((msg, body.size())), &item.1);
match body.size() {
BodySize::None | BodySize::Empty => {}
_ => {
log::error!("Stream responses are not supported for upgrade requests");
.encode(Message::Item((msg, body.size())), &io.1);
if result.is_ok() {
match body.size() {
BodySize::None | BodySize::Empty => {
*this.st = State::StopIo(io)
}
_ => {
*this.st = State::SendPayloadAndStop {
body,
boxed_io: Some(io),
}
}
}
} else {
*this.st = State::StopIo(io);
}
*this.st = State::StopIo(item);
}
Err(e) => {
log::error!(
@ -312,6 +327,51 @@ where
}
}
}
// send response body
State::SendPayloadAndStop {
ref mut body,
ref mut boxed_io,
} => {
let io = boxed_io.as_ref().unwrap();
if io.0.is_closed() {
*this.st = State::Stop;
} else {
if let Poll::Ready(Err(err)) =
_poll_request_payload(&io.0, &mut this.inner.payload, cx)
{
this.inner.error = Some(err);
}
loop {
let _ = ready!(io.0.poll_flush(cx, false));
let item = ready!(body.poll_next_chunk(cx));
match item {
Some(Ok(item)) => {
trace!("got response chunk: {:?}", item.len());
if let Err(e) =
io.0.encode(Message::Chunk(Some(item)), &io.1)
{
trace!("Cannot encode chunk: {:?}", e);
} else {
continue;
}
}
None => {
trace!("response payload eof {:?}", this.inner.flags);
if let Err(e) = io.0.encode(Message::Chunk(None), &io.1)
{
trace!("Cannot encode payload eof: {:?}", e);
}
}
Some(Err(e)) => {
trace!("error during response body poll: {:?}", e);
}
}
*this.st = State::StopIo(boxed_io.take().unwrap());
break;
}
}
}
// read first request and call service
State::ReadFirstRequest => {
*this.st = ready!(this.inner.read_request(cx, &mut this.call));
@ -647,89 +707,7 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<(), DispatchError>> {
// check if payload data is required
let payload = if let Some(ref mut payload) = self.payload {
payload
} else {
return Poll::Ready(Ok(()));
};
match payload.1.poll_data_required(cx) {
PayloadStatus::Read => {
let io = &self.io;
// read request payload
let mut updated = false;
loop {
match io.poll_recv(&payload.0, cx) {
Poll::Ready(Ok(PayloadItem::Chunk(chunk))) => {
updated = true;
payload.1.feed_data(chunk);
}
Poll::Ready(Ok(PayloadItem::Eof)) => {
updated = true;
payload.1.feed_eof();
self.payload = None;
break;
}
Poll::Ready(Err(err)) => {
let err = match err {
RecvError::WriteBackpressure => {
if io.poll_flush(cx, false)?.is_pending() {
break;
} else {
continue;
}
}
RecvError::KeepAlive => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(io::ErrorKind::Other, "Keep-alive")
.into()
}
RecvError::Stop => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(
io::ErrorKind::Other,
"Dispatcher stopped",
)
.into()
}
RecvError::PeerGone(err) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
if let Some(err) = err {
DispatchError::PeerGone(Some(err))
} else {
ParseError::Incomplete.into()
}
}
RecvError::Decoder(e) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
DispatchError::Parse(e)
}
};
return Poll::Ready(Err(err));
}
Poll::Pending => break,
}
}
if updated {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
PayloadStatus::Pause => Poll::Pending,
PayloadStatus::Dropped => {
// service call is not interested in payload
// wait until future completes and then close
// connection
self.payload = None;
Poll::Ready(Err(DispatchError::PayloadIsNotConsumed))
}
}
_poll_request_payload(&self.io, &mut self.payload, cx)
}
/// check for io changes, could close while waiting for service call
@ -746,6 +724,91 @@ where
}
}
/// Process request's payload
fn _poll_request_payload<F>(
io: &Io<F>,
slf_payload: &mut Option<(PayloadDecoder, PayloadSender)>,
cx: &mut Context<'_>,
) -> Poll<Result<(), DispatchError>> {
// check if payload data is required
let payload = if let Some(ref mut payload) = slf_payload {
payload
} else {
return Poll::Ready(Ok(()));
};
match payload.1.poll_data_required(cx) {
PayloadStatus::Read => {
// read request payload
let mut updated = false;
loop {
match io.poll_recv(&payload.0, cx) {
Poll::Ready(Ok(PayloadItem::Chunk(chunk))) => {
updated = true;
payload.1.feed_data(chunk);
}
Poll::Ready(Ok(PayloadItem::Eof)) => {
updated = true;
payload.1.feed_eof();
*slf_payload = None;
break;
}
Poll::Ready(Err(err)) => {
let err = match err {
RecvError::WriteBackpressure => {
if io.poll_flush(cx, false)?.is_pending() {
break;
} else {
continue;
}
}
RecvError::KeepAlive => {
payload.1.set_error(PayloadError::EncodingCorrupted);
*slf_payload = None;
io::Error::new(io::ErrorKind::Other, "Keep-alive").into()
}
RecvError::Stop => {
payload.1.set_error(PayloadError::EncodingCorrupted);
*slf_payload = None;
io::Error::new(io::ErrorKind::Other, "Dispatcher stopped")
.into()
}
RecvError::PeerGone(err) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
*slf_payload = None;
if let Some(err) = err {
DispatchError::PeerGone(Some(err))
} else {
ParseError::Incomplete.into()
}
}
RecvError::Decoder(e) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
*slf_payload = None;
DispatchError::Parse(e)
}
};
return Poll::Ready(Err(err));
}
Poll::Pending => break,
}
}
if updated {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
PayloadStatus::Pause => Poll::Pending,
PayloadStatus::Dropped => {
// service call is not interested in payload
// wait until future completes and then close
// connection
*slf_payload = None;
Poll::Ready(Err(DispatchError::PayloadIsNotConsumed))
}
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

View file

@ -452,7 +452,7 @@ impl ResponseBuilder {
/// .finish()
/// }
/// ```
pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self {
pub fn cookie(&mut self, cookie: Cookie<'_>) -> &mut Self {
if self.cookies.is_none() {
let mut jar = CookieJar::new();
jar.add(cookie.into_owned());
@ -479,7 +479,7 @@ impl ResponseBuilder {
/// builder.finish()
/// }
/// ```
pub fn del_cookie<'c>(&mut self, cookie: &Cookie<'c>) -> &mut Self {
pub fn del_cookie(&mut self, cookie: &Cookie<'_>) -> &mut Self {
if self.cookies.is_none() {
self.cookies = Some(CookieJar::new())
}

View file

@ -115,7 +115,7 @@ impl TestRequest {
#[cfg(feature = "cookie")]
/// Set cookie for this request
pub fn cookie<'a>(&mut self, cookie: Cookie<'a>) -> &mut Self {
pub fn cookie(&mut self, cookie: Cookie<'_>) -> &mut Self {
parts(&mut self.0).cookies.add(cookie.into_owned());
self
}

View file

@ -11,15 +11,12 @@ use ntex::http::header::{
ContentEncoding, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TYPE,
TRANSFER_ENCODING,
};
use ntex::http::{body::Body, client};
use ntex::http::{Method, StatusCode};
use ntex::http::{body::Body, client, ConnectionType, Method, StatusCode};
use ntex::time::{sleep, Millis, Seconds, Sleep};
use ntex::util::{ready, Bytes, Ready, Stream};
use ntex::web::middleware::Compress;
use ntex::web::{
self, test, App, BodyEncoding, HttpRequest, HttpResponse, WebResponseError,
};
use ntex::web::{self, middleware::Compress, test};
use ntex::web::{App, BodyEncoding, HttpRequest, HttpResponse, WebResponseError};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
Hello World Hello World Hello World Hello World Hello World \
@ -1199,3 +1196,43 @@ async fn test_web_server() {
system.stop();
}
#[ntex::test]
async fn web_no_ws_payload() {
let srv = test::server_with(test::config().h1(), || {
App::new()
.service(web::resource("/").route(web::get().to(move || async {
HttpResponse::Ok()
.streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24))
})))
.service(
web::resource("/f")
.route(web::get().to(move || async { HttpResponse::Ok().body(STR) })),
)
});
let client = client::Client::build().timeout(Seconds(30)).finish();
let mut response = client
.request(Method::GET, format!("http://{:?}/f", srv.addr()))
.header("sec-websocket-version", "13")
.header("upgrade", "websocket")
.header("sec-websocket-key", "ld75/p3D5ju5UhWsNMcJHA==")
.set_connection_type(ConnectionType::Upgrade)
.send()
.await
.unwrap();
let body = response.body().await.unwrap();
assert_eq!(body, STR);
let mut response = client
.request(Method::GET, format!("http://{:?}/", srv.addr()))
.header("sec-websocket-version", "13")
.header("upgrade", "websocket")
.header("sec-websocket-key", "ld75/p3D5ju5UhWsNMcJHA==")
.set_connection_type(ConnectionType::Upgrade)
.send()
.await
.unwrap();
let body = response.body().await.unwrap();
assert_eq!(body, STR);
}