http: Pass io stream to upgrade handler

This commit is contained in:
Nikolay Kim 2021-01-24 22:23:19 +06:00
parent c47ec4ae25
commit f0fe2bbc59
13 changed files with 667 additions and 164 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.2.0-b.6] - 2021-01-24
* http: Pass io stream to upgrade handler
## [0.2.0-b.5] - 2021-01-23
* accept shared ref in some methods of framed::State type

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.2.0-b.5"
version = "0.2.0-b.6"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"

View file

@ -39,6 +39,9 @@ where
if self.state.is_io_shutdown() {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
} else if self.state.is_io_stop() {
self.state.dsp_wake_task();
Poll::Ready(())
} else if self.state.is_read_paused() {
self.state.register_read_task(cx.waker());
Poll::Pending

View file

@ -13,22 +13,26 @@ use crate::task::LocalWaker;
const HW: usize = 8 * 1024;
bitflags::bitflags! {
pub struct Flags: u8 {
pub struct Flags: u16 {
const DSP_STOP = 0b0000_0001;
const DSP_KEEPALIVE = 0b0000_0010;
/// io error occured
const IO_ERR = 0b0000_0100;
const IO_SHUTDOWN = 0b0000_1000;
/// stop io tasks
const IO_STOP = 0b0000_1000;
/// shutdown io tasks
const IO_SHUTDOWN = 0b0001_0000;
/// pause io read
const RD_PAUSED = 0b0001_0000;
const RD_PAUSED = 0b0010_0000;
/// new data is available
const RD_READY = 0b0010_0000;
const RD_READY = 0b0100_0000;
/// write buffer is full
const WR_NOT_READY = 0b0100_0000;
const WR_NOT_READY = 0b1000_0000;
const ST_DSP_ERR = 0b1000_0000;
const ST_DSP_ERR = 0b10000_0000;
}
}
@ -148,6 +152,11 @@ impl State {
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN)
}
#[inline]
pub fn is_io_stop(&self) -> bool {
self.0.flags.get().contains(Flags::IO_STOP)
}
#[inline]
/// Check if read buffer has new data
pub fn is_read_ready(&self) -> bool {
@ -317,6 +326,17 @@ impl State {
self.0.dispatch_task.register(waker);
}
#[inline]
/// Stop io tasks
pub fn dsp_stop_io(&self, waker: &Waker) {
let mut flags = self.0.flags.get();
flags.insert(Flags::IO_STOP);
self.0.flags.set(flags);
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.register(waker);
}
#[inline]
/// Wake dispatcher
pub fn dsp_wake_task(&self) {
@ -329,6 +349,14 @@ impl State {
self.0.dispatch_task.register(waker);
}
#[inline]
/// Reset io stop flags
pub fn reset_io_stop(&self) {
let mut flags = self.0.flags.get();
flags.remove(Flags::IO_STOP);
self.0.flags.set(flags);
}
fn mark_io_error(&self) {
self.0.read_task.wake();
self.0.write_task.wake();

View file

@ -74,6 +74,9 @@ where
if this.state.is_io_err() {
log::trace!("write io is closed");
return Poll::Ready(());
} else if this.state.is_io_stop() {
self.state.dsp_wake_task();
return Poll::Ready(());
}
match this.st {
@ -224,7 +227,7 @@ where
}
}
}
// log::trace!("flushed {} bytes", written);
log::trace!("flushed {} bytes", written);
// remove written data
if written == len {

View file

@ -55,7 +55,7 @@ where
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, State, Codec), Response = ()>,
U: ServiceFactory<Config = (), Request = (Request, T, State, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
@ -142,7 +142,7 @@ where
F: IntoServiceFactory<U1>,
U1: ServiceFactory<
Config = (),
Request = (Request, State, Codec),
Request = (Request, T, State, Codec),
Response = (),
>,
U1::Error: fmt::Display + Error + 'static,

View file

@ -1,21 +1,20 @@
//! Framed transport dispatcher
use std::error::Error;
use std::task::{Context, Poll};
use std::{
cell::RefCell, fmt, future::Future, marker::PhantomData, net, pin::Pin, rc::Rc,
time::Duration, time::Instant,
cell::RefCell, error::Error, fmt, marker::PhantomData, net, pin::Pin, rc::Rc, time,
};
use bytes::Bytes;
use futures::Future;
use crate::codec::{AsyncRead, AsyncWrite, Decoder};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::{ReadTask, State as IoState, WriteTask};
use crate::service::Service;
use crate::http;
use crate::http::body::{BodySize, MessageBody, ResponseBody};
use crate::http::config::DispatcherConfig;
use crate::http::error::{DispatchError, PayloadError, ResponseError};
use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError};
use crate::http::helpers::DataFactory;
use crate::http::request::Request;
use crate::http::response::Response;
@ -37,11 +36,11 @@ bitflags::bitflags! {
pin_project_lite::pin_project! {
/// Dispatcher for HTTP/1.1 protocol
pub struct Dispatcher<S: Service, B, X: Service, U: Service> {
pub struct Dispatcher<T, S: Service, B, X: Service, U: Service> {
#[pin]
call: CallState<S, X, U>,
st: State<B>,
inner: DispatcherInner<S, B, X, U>,
inner: DispatcherInner<T, S, B, X, U>,
}
}
@ -50,6 +49,7 @@ enum State<B> {
ReadRequest,
ReadPayload,
SendPayload { body: ResponseBody<B> },
Upgrade(Option<Request>),
Stop,
}
@ -63,12 +63,13 @@ pin_project_lite::pin_project! {
}
}
struct DispatcherInner<S, B, X, U> {
struct DispatcherInner<T, S, B, X, U> {
io: Option<Rc<RefCell<T>>>,
flags: Flags,
codec: Codec,
config: Rc<DispatcherConfig<S, X, U>>,
state: IoState,
expire: Instant,
expire: time::Instant,
error: Option<DispatchError>,
payload: Option<(PayloadDecoder, PayloadSender)>,
peer_addr: Option<net::SocketAddr>,
@ -77,34 +78,38 @@ struct DispatcherInner<S, B, X, U> {
}
#[derive(Copy, Clone, PartialEq, Eq)]
enum PollPayloadStatus {
enum ReadPayloadStatus {
Done,
Updated,
Pending,
Dropped,
}
impl<S, B, X, U> Dispatcher<S, B, X, U>
enum WritePayloadStatus<B> {
Next(State<B>),
Pause,
Continue,
}
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
U: Service<Request = (Request, IoState, Codec), Response = ()>,
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
U::Error: Error + fmt::Display,
{
/// Construct new `Dispatcher` instance with outgoing messages stream.
pub(in crate::http) fn new<T>(
pub(in crate::http) fn new(
io: T,
config: Rc<DispatcherConfig<S, X, U>>,
peer_addr: Option<net::SocketAddr>,
on_connect_data: Option<Box<dyn DataFactory>>,
) -> Self
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
) -> Self {
let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled());
let state = IoState::new();
@ -115,18 +120,19 @@ where
// slow-request timer
if config.client_timeout != 0 {
expire += Duration::from_secs(config.client_timeout);
expire += time::Duration::from_secs(config.client_timeout);
config.timer_h1.register(expire, expire, &state);
}
// start support io tasks
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(WriteTask::new(io, state.clone()));
crate::rt::spawn(WriteTask::new(io.clone(), state.clone()));
Dispatcher {
call: CallState::None,
st: State::ReadRequest,
inner: DispatcherInner {
io: Some(io),
flags: Flags::empty(),
error: None,
payload: None,
@ -142,15 +148,16 @@ where
}
}
impl<S, B, X, U> Future for Dispatcher<S, B, X, U>
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, IoState, Codec), Response = ()>,
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
U::Error: Error + fmt::Display + 'static,
{
type Output = Result<(), DispatchError>;
@ -180,7 +187,7 @@ where
// we might need to read more data into a request payload
// (ie service future can wait for payload data)
if this.inner.poll_read_payload(cx)
!= PollPayloadStatus::Updated
!= ReadPayloadStatus::Updated
{
return Poll::Pending;
}
@ -197,26 +204,15 @@ where
b"HTTP/1.1 100 Continue\r\n\r\n",
)
});
Some(if this.inner.flags.contains(Flags::UPGRADE) {
// Handle UPGRADE request
CallState::Upgrade {
fut: this
.inner
.config
.upgrade
.as_ref()
.unwrap()
.call((
req,
this.inner.state.clone(),
this.inner.codec.clone(),
)),
}
if this.inner.flags.contains(Flags::UPGRADE) {
this.inner.state.dsp_stop_io(cx.waker());
*this.st = State::Upgrade(Some(req));
return Poll::Pending;
} else {
CallState::Service {
Some(CallState::Service {
fut: this.inner.config.service.call(req),
}
})
})
}
}
Err(e) => {
*this.st = this.inner.handle_error(e, true);
@ -273,7 +269,11 @@ where
if this.inner.state.is_read_ready() {
match this.inner.state.decode_item(&this.inner.codec) {
Ok(Some((mut req, pl))) => {
log::trace!("http message is received: {:?}", req);
log::trace!(
"http message is received: {:?} and payload {:?}",
req,
pl
);
req.head_mut().peer_addr = this.inner.peer_addr;
// configure request payload
@ -313,48 +313,46 @@ where
on_connect.set(&mut req.extensions_mut());
}
// call service
*this.st = State::Call;
this.call.set(if req.head().expect() {
if req.head().expect() {
// call service
*this.st = State::Call;
// Handle `EXPECT: 100-Continue` header
CallState::Expect {
this.call.set(CallState::Expect {
fut: this.inner.config.expect.call(req),
}
});
} else if upgrade {
log::trace!("initate upgrade handling");
log::trace!("prep io for upgrade handler");
// Handle UPGRADE request
CallState::Upgrade {
fut: this
.inner
.config
.upgrade
.as_ref()
.unwrap()
.call((
req,
this.inner.state.clone(),
this.inner.codec.clone(),
)),
}
this.inner.state.dsp_stop_io(cx.waker());
*this.st = State::Upgrade(Some(req));
return Poll::Pending;
} else {
// Handle normal requests
CallState::Service {
*this.st = State::Call;
this.call.set(CallState::Service {
fut: this.inner.config.service.call(req),
}
});
});
}
}
Ok(None) => {
// if connection is not keep-alive then disconnect
log::trace!("not enough data to decode next frame, register dispatch task");
// if io error occured or connection is not keep-alive
// then disconnect
if this.inner.flags.contains(Flags::STARTED)
&& !this.inner.flags.contains(Flags::KEEPALIVE)
&& (!this.inner.flags.contains(Flags::KEEPALIVE)
|| !this.inner.codec.keepalive_enabled()
|| this.inner.state.is_io_err())
{
*this.st = State::Stop;
this.inner.state.dsp_mark_stopped();
continue;
}
this.inner.state.dsp_read_more_data(cx.waker());
return Poll::Pending;
}
Err(err) => {
log::trace!("malformed request: {:?}", err);
// Malformed requests, respond with 400
let (res, body) =
Response::BadRequest().finish().into_parts();
@ -379,31 +377,70 @@ where
// consume request's payload
State::ReadPayload => loop {
match this.inner.poll_read_payload(cx) {
PollPayloadStatus::Updated => continue,
PollPayloadStatus::Pending => return Poll::Pending,
PollPayloadStatus::Done => {
ReadPayloadStatus::Updated => continue,
ReadPayloadStatus::Pending => return Poll::Pending,
ReadPayloadStatus::Done => {
*this.st = {
this.inner.reset_keepalive();
State::ReadRequest
}
}
PollPayloadStatus::Dropped => *this.st = State::Stop,
ReadPayloadStatus::Dropped => *this.st = State::Stop,
}
break;
},
// send response body
State::SendPayload { ref mut body } => {
this.inner.poll_read_payload(cx);
if this.inner.state.is_io_err() {
*this.st = State::Stop;
} else {
this.inner.poll_read_payload(cx);
match body.poll_next_chunk(cx) {
Poll::Ready(item) => {
if let Some(st) = this.inner.send_payload(item) {
*this.st = st;
}
match body.poll_next_chunk(cx) {
Poll::Ready(item) => match this.inner.send_payload(item) {
WritePayloadStatus::Next(st) => {
*this.st = st;
}
WritePayloadStatus::Pause => {
this.inner.state.dsp_flush_write_data(cx.waker());
return Poll::Pending;
}
WritePayloadStatus::Continue => (),
},
Poll::Pending => return Poll::Pending,
}
Poll::Pending => return Poll::Pending,
}
}
// stop io tasks and call upgrade service
State::Upgrade(ref mut req) => {
// check if all io tasks have been stopped
let io = if Rc::strong_count(this.inner.io.as_ref().unwrap()) == 1 {
if let Ok(io) = Rc::try_unwrap(this.inner.io.take().unwrap()) {
io.into_inner()
} else {
return Poll::Ready(Err(DispatchError::InternalError));
}
} else {
// wait next task stop
this.inner.state.dsp_register_task(cx.waker());
return Poll::Pending;
};
log::trace!("initate upgrade handling");
let req = req.take().unwrap();
*this.st = State::Call;
this.inner.state.reset_io_stop();
// Handle UPGRADE request
this.call.set(CallState::Upgrade {
fut: this.inner.config.upgrade.as_ref().unwrap().call((
req,
io,
this.inner.state.clone(),
this.inner.codec.clone(),
)),
});
}
// prepare to shutdown
State::Stop => {
this.inner.state.shutdown_io();
@ -426,7 +463,7 @@ where
}
}
impl<S, B, X, U> DispatcherInner<S, B, X, U>
impl<T, S, B, X, U> DispatcherInner<T, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
@ -442,8 +479,8 @@ where
fn reset_keepalive(&mut self) {
// re-register keep-alive
if self.flags.contains(Flags::KEEPALIVE) {
let expire =
self.config.timer_h1.now() + Duration::from_secs(self.config.keep_alive);
let expire = self.config.timer_h1.now()
+ time::Duration::from_secs(self.config.keep_alive);
self.config
.timer_h1
.register(expire, self.expire, &self.state);
@ -512,18 +549,25 @@ where
fn send_payload(
&mut self,
item: Option<Result<Bytes, Box<dyn Error>>>,
) -> Option<State<B>> {
) -> WritePayloadStatus<B> {
match item {
Some(Ok(item)) => {
trace!("Got response chunk: {:?}", item.len());
if let Err(err) = self
match self
.state
.write_item(Message::Chunk(Some(item)), &self.codec)
{
self.error = Some(DispatchError::Encode(err));
Some(State::Stop)
} else {
None
Err(err) => {
self.error = Some(DispatchError::Encode(err));
WritePayloadStatus::Next(State::Stop)
}
Ok(has_space) => {
if has_space {
WritePayloadStatus::Continue
} else {
WritePayloadStatus::Pause
}
}
}
}
None => {
@ -532,24 +576,24 @@ where
self.state.write_item(Message::Chunk(None), &self.codec)
{
self.error = Some(DispatchError::Encode(err));
Some(State::Stop)
WritePayloadStatus::Next(State::Stop)
} else if self.payload.is_some() {
Some(State::ReadPayload)
WritePayloadStatus::Next(State::ReadPayload)
} else {
self.reset_keepalive();
Some(State::ReadRequest)
WritePayloadStatus::Next(State::ReadRequest)
}
}
Some(Err(e)) => {
trace!("Error during response body poll: {:?}", e);
self.error = Some(DispatchError::ResponsePayload(e));
Some(State::Stop)
WritePayloadStatus::Next(State::Stop)
}
}
}
/// Process request's payload
fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> PollPayloadStatus {
fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> ReadPayloadStatus {
// check if payload data is required
if let Some(ref mut payload) = self.payload {
match payload.1.poll_data_required(cx) {
@ -557,7 +601,7 @@ where
// read request payload
let mut updated = false;
loop {
let item = self.state.with_read_buf(|buf| payload.0.decode(buf));
let item = self.state.decode_item(&payload.0);
match item {
Ok(Some(PayloadItem::Chunk(chunk))) => {
updated = true;
@ -567,40 +611,434 @@ where
payload.1.feed_eof();
self.payload = None;
if !updated {
return PollPayloadStatus::Done;
return ReadPayloadStatus::Done;
}
break;
}
Ok(None) => {
self.state.dsp_read_more_data(cx.waker());
break;
if self.state.is_io_err() {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
self.error = Some(ParseError::Incomplete.into());
return ReadPayloadStatus::Dropped;
} else {
self.state.dsp_read_more_data(cx.waker());
break;
}
}
Err(e) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
self.error = Some(DispatchError::Parse(e));
return PollPayloadStatus::Dropped;
return ReadPayloadStatus::Dropped;
}
}
}
if updated {
PollPayloadStatus::Updated
ReadPayloadStatus::Updated
} else {
PollPayloadStatus::Pending
ReadPayloadStatus::Pending
}
}
PayloadStatus::Pause => PollPayloadStatus::Pending,
PayloadStatus::Pause => ReadPayloadStatus::Pending,
PayloadStatus::Dropped => {
// service call is not interested in payload
// wait until future completes and then close
// connection
self.payload = None;
self.error = Some(DispatchError::PayloadIsNotConsumed);
PollPayloadStatus::Dropped
ReadPayloadStatus::Dropped
}
}
} else {
PollPayloadStatus::Done
ReadPayloadStatus::Done
}
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::{io, sync::Arc};
use bytes::{Bytes, BytesMut};
use futures::future::{lazy, ok, FutureExt};
use futures::StreamExt;
use rand::Rng;
use super::*;
use crate::codec::Decoder;
use crate::http::config::{DispatcherConfig, ServiceConfig};
use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler};
use crate::http::{body, Request, ResponseHead, StatusCode};
use crate::rt::time::delay_for;
use crate::service::IntoService;
use crate::testing::Io;
const BUFFER_SIZE: usize = 32_768;
/// Create http/1 dispatcher.
pub(crate) fn h1<F, S, B>(
stream: Io,
service: F,
) -> Dispatcher<Io, S, B, ExpectHandler, UpgradeHandler<Io>>
where
F: IntoService<S>,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
{
Dispatcher::new(
stream,
Rc::new(DispatcherConfig::new(
ServiceConfig::default(),
service.into_service(),
ExpectHandler,
None,
)),
None,
None,
)
}
pub(crate) fn spawn_h1<F, S, B>(stream: Io, service: F)
where
F: IntoService<S>,
S: Service<Request = Request> + 'static,
S::Error: ResponseError,
S::Response: Into<Response<B>>,
B: MessageBody + 'static,
{
crate::rt::spawn(
Dispatcher::<Io, S, B, ExpectHandler, UpgradeHandler<Io>>::new(
stream,
Rc::new(DispatcherConfig::new(
ServiceConfig::default(),
service.into_service(),
ExpectHandler,
None,
)),
None,
None,
),
);
}
fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead {
decoder.decode(buf).unwrap().unwrap()
}
#[ntex_rt::test]
async fn test_req_parse_err() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
delay_for(time::Duration::from_millis(50)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(!h1.inner.state.is_open());
delay_for(time::Duration::from_millis(50)).await;
client
.local_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n"));
client.close().await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
// assert!(h1.inner.flags.contains(Flags::SHUTDOWN_IO));
assert!(h1.inner.state.is_io_err());
}
#[ntex_rt::test]
async fn test_pipeline() {
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
let mut decoder = ClientCodec::default();
spawn_h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
client.write("GET /test HTTP/1.1\r\n\r\n");
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(!client.is_server_dropped());
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(decoder.decode(&mut buf).unwrap().is_none());
assert!(!client.is_server_dropped());
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex_rt::test]
async fn test_pipeline_with_payload() {
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
let mut decoder = ClientCodec::default();
spawn_h1(server, |mut req: Request| async move {
let mut p = req.take_payload();
while let Some(_) = p.next().await {}
Ok::<_, io::Error>(Response::Ok().finish())
});
client.write("GET /test HTTP/1.1\r\ncontent-length: 5\r\n\r\n");
delay_for(time::Duration::from_millis(50)).await;
client.write("xxxxx");
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(!client.is_server_dropped());
client.write("GET /test HTTP/1.1\r\n\r\n");
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(decoder.decode(&mut buf).unwrap().is_none());
assert!(!client.is_server_dropped());
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex_rt::test]
async fn test_pipeline_with_delay() {
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
let mut decoder = ClientCodec::default();
spawn_h1(server, |_| async {
delay_for(time::Duration::from_millis(100)).await;
Ok::<_, io::Error>(Response::Ok().finish())
});
client.write("GET /test HTTP/1.1\r\n\r\n");
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(!client.is_server_dropped());
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
delay_for(time::Duration::from_millis(50)).await;
client.write("GET /test HTTP/1.1\r\n\r\n");
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(decoder.decode(&mut buf).unwrap().is_none());
assert!(!client.is_server_dropped());
buf.extend(client.read().await.unwrap());
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(decoder.decode(&mut buf).unwrap().is_none());
assert!(!client.is_server_dropped());
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex_rt::test]
/// if socket is disconnected, h1 dispatcher does not process any data
// /// h1 dispatcher still processes all incoming requests
// /// but it does not write any data to socket
async fn test_write_disconnected() {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
let (client, server) = Io::create();
spawn_h1(server, move |_| {
num2.fetch_add(1, Ordering::Relaxed);
ok::<_, io::Error>(Response::Ok().finish())
});
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
client.close().await;
assert!(client.is_server_dropped());
assert!(client.read_any().is_empty());
// only first request get handled
assert_eq!(num.load(Ordering::Relaxed), 0);
}
#[ntex_rt::test]
async fn test_read_large_message() {
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
let mut decoder = ClientCodec::default();
let data = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(70_000)
.map(char::from)
.collect::<String>();
client.write("GET /test HTTP/1.1\r\nContent-Length: ");
client.write(data);
delay_for(time::Duration::from_millis(50)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
delay_for(time::Duration::from_millis(50)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(!h1.inner.state.is_open());
let mut buf = client.read().await.unwrap();
assert_eq!(load(&mut decoder, &mut buf).status, StatusCode::BAD_REQUEST);
}
#[ntex_rt::test]
async fn test_read_backpressure() {
let mark = Arc::new(AtomicBool::new(false));
let mark2 = mark.clone();
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
spawn_h1(server, move |mut req: Request| {
let m = mark2.clone();
async move {
// read one chunk
let mut pl = req.take_payload();
let _ = pl.next().await.unwrap().unwrap();
m.store(true, Ordering::Relaxed);
// sleep
delay_for(time::Duration::from_secs(999_999)).await;
Ok::<_, io::Error>(Response::Ok().finish())
}
});
client.write("GET /test HTTP/1.1\r\nContent-Length: 1048576\r\n\r\n");
delay_for(time::Duration::from_millis(50)).await;
// buf must be consumed
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
// io should be drained only by no more than MAX_BUFFER_SIZE
let random_bytes: Vec<u8> =
(0..1_048_576).map(|_| rand::random::<u8>()).collect();
client.write(random_bytes);
delay_for(time::Duration::from_millis(50)).await;
assert!(client.remote_buffer(|buf| buf.len()) > 1_048_576 - BUFFER_SIZE * 3);
assert!(mark.load(Ordering::Relaxed));
}
#[ntex_rt::test]
async fn test_write_backpressure() {
std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace");
env_logger::init();
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
struct Stream(Arc<AtomicUsize>);
impl body::MessageBody for Stream {
fn size(&self) -> body::BodySize {
body::BodySize::Stream
}
fn poll_next_chunk(
&mut self,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Box<dyn std::error::Error>>>> {
let data = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(65_536)
.map(char::from)
.collect::<String>();
self.0.fetch_add(data.len(), Ordering::Relaxed);
Poll::Ready(Some(Ok(Bytes::from(data))))
}
}
let (client, server) = Io::create();
let mut h1 = h1(server, move |_| {
let n = num2.clone();
async move { Ok::<_, io::Error>(Response::Ok().message_body(Stream(n.clone()))) }
.boxed_local()
});
let state = h1.inner.state.clone();
// do not allow to write to socket
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1.1\r\n\r\n");
delay_for(time::Duration::from_millis(50)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
// buf must be consumed
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
// amount of generated data
assert_eq!(num.load(Ordering::Relaxed), 65_536);
// response message + chunking encoding
assert_eq!(state.with_write_buf(|buf| buf.len()), 65629);
client.remote_buffer_cap(65536);
delay_for(time::Duration::from_millis(50)).await;
assert_eq!(state.with_write_buf(|buf| buf.len()), 93);
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
assert_eq!(num.load(Ordering::Relaxed), 65_536 * 2);
}
#[ntex_rt::test]
async fn test_disconnect_during_response_body_pending() {
struct Stream(bool);
impl body::MessageBody for Stream {
fn size(&self) -> body::BodySize {
body::BodySize::Sized(2048)
}
fn poll_next_chunk(
&mut self,
_: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Box<dyn std::error::Error>>>> {
if self.0 {
Poll::Pending
} else {
self.0 = true;
let data = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(1024)
.map(char::from)
.collect::<String>();
Poll::Ready(Some(Ok(Bytes::from(data))))
}
}
}
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
let mut h1 = h1(server, |_| {
ok::<_, io::Error>(Response::Ok().message_body(Stream(false)))
});
client.write("GET /test HTTP/1.1\r\n\r\n");
delay_for(time::Duration::from_millis(50)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
// http message must be consumed
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
let mut decoder = ClientCodec::default();
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
client.close().await;
delay_for(time::Duration::from_millis(50)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
}
}

View file

@ -68,7 +68,11 @@ where
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, IoState, Codec), Response = ()>,
U: ServiceFactory<
Config = (),
Request = (Request, TcpStream, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
@ -112,7 +116,7 @@ mod openssl {
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, IoState, Codec),
Request = (Request, SslStream<TcpStream>, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + Error + 'static,
@ -166,7 +170,7 @@ mod rustls {
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, IoState, Codec),
Request = (Request, TlsStream<TcpStream>, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + Error + 'static,
@ -228,7 +232,7 @@ where
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, S, B, X, U1>
where
U1: ServiceFactory<Request = (Request, IoState, Codec), Response = ()>,
U1: ServiceFactory<Request = (Request, T, IoState, Codec), Response = ()>,
U1::Error: fmt::Display + Error + 'static,
U1::InitError: fmt::Debug,
U1::Future: 'static,
@ -267,7 +271,11 @@ where
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, IoState, Codec), Response = ()>,
U: ServiceFactory<
Config = (),
Request = (Request, T, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
@ -331,13 +339,13 @@ where
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, IoState, Codec), Response = ()>,
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
{
type Request = (T, Option<net::SocketAddr>);
type Response = ();
type Error = DispatchError;
type Future = Dispatcher<S, B, X, U>;
type Future = Dispatcher<T, S, B, X, U>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let cfg = self.config.as_ref();

View file

@ -11,7 +11,7 @@ pub struct UpgradeHandler<T>(PhantomData<T>);
impl<T> ServiceFactory for UpgradeHandler<T> {
type Config = ();
type Request = (Request, State, Codec);
type Request = (Request, T, State, Codec);
type Response = ();
type Error = io::Error;
type Service = UpgradeHandler<T>;
@ -25,7 +25,7 @@ impl<T> ServiceFactory for UpgradeHandler<T> {
}
impl<T> Service for UpgradeHandler<T> {
type Request = (Request, State, Codec);
type Request = (Request, T, State, Codec);
type Response = ();
type Error = io::Error;
type Future = Ready<Result<Self::Response, Self::Error>>;

View file

@ -126,7 +126,7 @@ where
where
U1: ServiceFactory<
Config = (),
Request = (Request, State, h1::Codec),
Request = (Request, T, State, h1::Codec),
Response = (),
>,
U1::Error: fmt::Display + error::Error + 'static,
@ -167,7 +167,11 @@ where
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, State, h1::Codec), Response = ()>,
U: ServiceFactory<
Config = (),
Request = (Request, TcpStream, State, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
@ -213,7 +217,7 @@ mod openssl {
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, State, h1::Codec),
Request = (Request, SslStream<TcpStream>, State, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + error::Error + 'static,
@ -278,7 +282,7 @@ mod rustls {
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, State, h1::Codec),
Request = (Request, TlsStream<TcpStream>, State, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + error::Error + 'static,
@ -342,7 +346,11 @@ where
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, State, h1::Codec), Response = ()>,
U: ServiceFactory<
Config = (),
Request = (Request, T, State, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
@ -410,7 +418,7 @@ where
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
{
type Request = (T, Protocol, Option<net::SocketAddr>);
@ -513,6 +521,7 @@ pin_project_lite::pin_project! {
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
T: 'static,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: 'static,
@ -523,7 +532,7 @@ pin_project_lite::pin_project! {
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U::Error: fmt::Display,
U::Error: error::Error,
U::Error: 'static,
@ -543,16 +552,17 @@ pin_project_lite::pin_project! {
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
T: 'static,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U::Error: fmt::Display,
U::Error: error::Error,
U::Error: 'static,
{
H1 { #[pin] fut: h1::Dispatcher<S, B, X, U> },
H1 { #[pin] fut: h1::Dispatcher<T, S, B, X, U> },
H2 { fut: Dispatcher<T, S, B, X, U> },
H2Handshake { data:
Option<(
@ -567,7 +577,7 @@ pin_project_lite::pin_project! {
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Future: 'static,
@ -575,7 +585,7 @@ where
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
{
type Output = Result<(), DispatchError>;

View file

@ -8,12 +8,12 @@ use ntex::framed::{DispatchItem, Dispatcher, State};
use ntex::http::test::server as test_server;
use ntex::http::ws::handshake_response;
use ntex::http::{body::BodySize, h1, HttpService, Request, Response};
use ntex::rt::net::TcpStream;
use ntex::ws;
async fn ws_service(
msg: DispatchItem<ws::Codec>,
) -> Result<Option<ws::Message>, io::Error> {
println!("TEST: {:?}", msg);
let msg = match msg {
DispatchItem::Item(msg) => match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
@ -31,33 +31,33 @@ async fn ws_service(
#[ntex::test]
async fn test_simple() {
std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace");
env_logger::init();
let mut srv = test_server(|| {
HttpService::build()
.upgrade(|(req, state, mut codec): (Request, State, h1::Codec)| {
async move {
let res = handshake_response(req.head()).finish();
.upgrade(
|(req, io, state, mut codec): (Request, TcpStream, State, h1::Codec)| {
async move {
let res = handshake_response(req.head()).finish();
// send handshake respone
state
.write_item(
h1::Message::Item((res.drop_body(), BodySize::None)),
&mut codec,
// send handshake respone
state
.write_item(
h1::Message::Item((res.drop_body(), BodySize::None)),
&mut codec,
)
.unwrap();
// start websocket service
Dispatcher::new(
io,
ws::Codec::default(),
state,
ws_service,
Default::default(),
)
.unwrap();
// start websocket service
Dispatcher::from_state(
ws::Codec::default(),
state,
ws_service,
Default::default(),
)
.await
}
})
.await
}
},
)
.finish(|_| ok::<_, io::Error>(Response::NotFound()))
.tcp()
});

View file

@ -85,6 +85,9 @@ async fn test_expect_continue() {
#[ntex::test]
async fn test_expect_continue_h1() {
std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace");
env_logger::init();
let srv = test_server(|| {
HttpService::build()
.expect(fn_service(|req: Request| {
@ -115,7 +118,9 @@ async fn test_expect_continue_h1() {
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
let mut data = String::new();
println!("1-------------------");
let _ = stream.read_to_string(&mut data);
println!("2-------------------");
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
}

View file

@ -1,20 +1,21 @@
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::{cell::Cell, io, pin::Pin};
use std::{cell::Cell, io, marker::PhantomData, pin::Pin};
use bytes::Bytes;
use futures::{future, Future, SinkExt, StreamExt};
use ntex::codec::{AsyncRead, AsyncWrite};
use ntex::framed::{DispatchItem, Dispatcher, State, Timer};
use ntex::http::{body, h1, test, ws::handshake, HttpService, Request, Response};
use ntex::service::{fn_factory, Service};
use ntex::ws;
struct WsService(Arc<Mutex<Cell<bool>>>);
struct WsService<T>(Arc<Mutex<Cell<bool>>>, PhantomData<T>);
impl WsService {
impl<T> WsService<T> {
fn new() -> Self {
WsService(Arc::new(Mutex::new(Cell::new(false))))
WsService(Arc::new(Mutex::new(Cell::new(false))), PhantomData)
}
fn set_polled(&self) {
@ -26,14 +27,17 @@ impl WsService {
}
}
impl Clone for WsService {
impl<T> Clone for WsService<T> {
fn clone(&self) -> Self {
WsService(self.0.clone())
WsService(self.0.clone(), PhantomData)
}
}
impl Service for WsService {
type Request = (Request, State, h1::Codec);
impl<T> Service for WsService<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Request = (Request, T, State, h1::Codec);
type Response = ();
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<(), io::Error>>>>;
@ -43,7 +47,7 @@ impl Service for WsService {
Poll::Ready(Ok(()))
}
fn call(&self, (req, state, mut codec): Self::Request) -> Self::Future {
fn call(&self, (req, io, state, mut codec): Self::Request) -> Self::Future {
let fut = async move {
let res = handshake(req.head()).unwrap().message_body(());
@ -51,7 +55,7 @@ impl Service for WsService {
.write_item((res, body::BodySize::None).into(), &mut codec)
.unwrap();
Dispatcher::from_state(ws::Codec::new(), state, service, Timer::default())
Dispatcher::new(io, ws::Codec::new(), state, service, Timer::default())
.await
.map_err(|_| panic!())
};