mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-04 13:27:39 +03:00
http: Pass io stream to upgrade handler
This commit is contained in:
parent
c47ec4ae25
commit
f0fe2bbc59
13 changed files with 667 additions and 164 deletions
|
@ -1,5 +1,9 @@
|
|||
# Changes
|
||||
|
||||
## [0.2.0-b.6] - 2021-01-24
|
||||
|
||||
* http: Pass io stream to upgrade handler
|
||||
|
||||
## [0.2.0-b.5] - 2021-01-23
|
||||
|
||||
* accept shared ref in some methods of framed::State type
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex"
|
||||
version = "0.2.0-b.5"
|
||||
version = "0.2.0-b.6"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "Framework for composable network services"
|
||||
readme = "README.md"
|
||||
|
|
|
@ -39,6 +39,9 @@ where
|
|||
if self.state.is_io_shutdown() {
|
||||
log::trace!("read task is instructed to shutdown");
|
||||
Poll::Ready(())
|
||||
} else if self.state.is_io_stop() {
|
||||
self.state.dsp_wake_task();
|
||||
Poll::Ready(())
|
||||
} else if self.state.is_read_paused() {
|
||||
self.state.register_read_task(cx.waker());
|
||||
Poll::Pending
|
||||
|
|
|
@ -13,22 +13,26 @@ use crate::task::LocalWaker;
|
|||
const HW: usize = 8 * 1024;
|
||||
|
||||
bitflags::bitflags! {
|
||||
pub struct Flags: u8 {
|
||||
pub struct Flags: u16 {
|
||||
const DSP_STOP = 0b0000_0001;
|
||||
const DSP_KEEPALIVE = 0b0000_0010;
|
||||
|
||||
/// io error occured
|
||||
const IO_ERR = 0b0000_0100;
|
||||
const IO_SHUTDOWN = 0b0000_1000;
|
||||
/// stop io tasks
|
||||
const IO_STOP = 0b0000_1000;
|
||||
/// shutdown io tasks
|
||||
const IO_SHUTDOWN = 0b0001_0000;
|
||||
|
||||
/// pause io read
|
||||
const RD_PAUSED = 0b0001_0000;
|
||||
const RD_PAUSED = 0b0010_0000;
|
||||
/// new data is available
|
||||
const RD_READY = 0b0010_0000;
|
||||
const RD_READY = 0b0100_0000;
|
||||
|
||||
/// write buffer is full
|
||||
const WR_NOT_READY = 0b0100_0000;
|
||||
const WR_NOT_READY = 0b1000_0000;
|
||||
|
||||
const ST_DSP_ERR = 0b1000_0000;
|
||||
const ST_DSP_ERR = 0b10000_0000;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -148,6 +152,11 @@ impl State {
|
|||
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn is_io_stop(&self) -> bool {
|
||||
self.0.flags.get().contains(Flags::IO_STOP)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if read buffer has new data
|
||||
pub fn is_read_ready(&self) -> bool {
|
||||
|
@ -317,6 +326,17 @@ impl State {
|
|||
self.0.dispatch_task.register(waker);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Stop io tasks
|
||||
pub fn dsp_stop_io(&self, waker: &Waker) {
|
||||
let mut flags = self.0.flags.get();
|
||||
flags.insert(Flags::IO_STOP);
|
||||
self.0.flags.set(flags);
|
||||
self.0.read_task.wake();
|
||||
self.0.write_task.wake();
|
||||
self.0.dispatch_task.register(waker);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Wake dispatcher
|
||||
pub fn dsp_wake_task(&self) {
|
||||
|
@ -329,6 +349,14 @@ impl State {
|
|||
self.0.dispatch_task.register(waker);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Reset io stop flags
|
||||
pub fn reset_io_stop(&self) {
|
||||
let mut flags = self.0.flags.get();
|
||||
flags.remove(Flags::IO_STOP);
|
||||
self.0.flags.set(flags);
|
||||
}
|
||||
|
||||
fn mark_io_error(&self) {
|
||||
self.0.read_task.wake();
|
||||
self.0.write_task.wake();
|
||||
|
|
|
@ -74,6 +74,9 @@ where
|
|||
if this.state.is_io_err() {
|
||||
log::trace!("write io is closed");
|
||||
return Poll::Ready(());
|
||||
} else if this.state.is_io_stop() {
|
||||
self.state.dsp_wake_task();
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
match this.st {
|
||||
|
@ -224,7 +227,7 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
// log::trace!("flushed {} bytes", written);
|
||||
log::trace!("flushed {} bytes", written);
|
||||
|
||||
// remove written data
|
||||
if written == len {
|
||||
|
|
|
@ -55,7 +55,7 @@ where
|
|||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<Config = (), Request = (Request, State, Codec), Response = ()>,
|
||||
U: ServiceFactory<Config = (), Request = (Request, T, State, Codec), Response = ()>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
|
@ -142,7 +142,7 @@ where
|
|||
F: IntoServiceFactory<U1>,
|
||||
U1: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, State, Codec),
|
||||
Request = (Request, T, State, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U1::Error: fmt::Display + Error + 'static,
|
||||
|
|
|
@ -1,21 +1,20 @@
|
|||
//! Framed transport dispatcher
|
||||
use std::error::Error;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{
|
||||
cell::RefCell, fmt, future::Future, marker::PhantomData, net, pin::Pin, rc::Rc,
|
||||
time::Duration, time::Instant,
|
||||
cell::RefCell, error::Error, fmt, marker::PhantomData, net, pin::Pin, rc::Rc, time,
|
||||
};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::Future;
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Decoder};
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::framed::{ReadTask, State as IoState, WriteTask};
|
||||
use crate::service::Service;
|
||||
|
||||
use crate::http;
|
||||
use crate::http::body::{BodySize, MessageBody, ResponseBody};
|
||||
use crate::http::config::DispatcherConfig;
|
||||
use crate::http::error::{DispatchError, PayloadError, ResponseError};
|
||||
use crate::http::error::{DispatchError, ParseError, PayloadError, ResponseError};
|
||||
use crate::http::helpers::DataFactory;
|
||||
use crate::http::request::Request;
|
||||
use crate::http::response::Response;
|
||||
|
@ -37,11 +36,11 @@ bitflags::bitflags! {
|
|||
|
||||
pin_project_lite::pin_project! {
|
||||
/// Dispatcher for HTTP/1.1 protocol
|
||||
pub struct Dispatcher<S: Service, B, X: Service, U: Service> {
|
||||
pub struct Dispatcher<T, S: Service, B, X: Service, U: Service> {
|
||||
#[pin]
|
||||
call: CallState<S, X, U>,
|
||||
st: State<B>,
|
||||
inner: DispatcherInner<S, B, X, U>,
|
||||
inner: DispatcherInner<T, S, B, X, U>,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,6 +49,7 @@ enum State<B> {
|
|||
ReadRequest,
|
||||
ReadPayload,
|
||||
SendPayload { body: ResponseBody<B> },
|
||||
Upgrade(Option<Request>),
|
||||
Stop,
|
||||
}
|
||||
|
||||
|
@ -63,12 +63,13 @@ pin_project_lite::pin_project! {
|
|||
}
|
||||
}
|
||||
|
||||
struct DispatcherInner<S, B, X, U> {
|
||||
struct DispatcherInner<T, S, B, X, U> {
|
||||
io: Option<Rc<RefCell<T>>>,
|
||||
flags: Flags,
|
||||
codec: Codec,
|
||||
config: Rc<DispatcherConfig<S, X, U>>,
|
||||
state: IoState,
|
||||
expire: Instant,
|
||||
expire: time::Instant,
|
||||
error: Option<DispatchError>,
|
||||
payload: Option<(PayloadDecoder, PayloadSender)>,
|
||||
peer_addr: Option<net::SocketAddr>,
|
||||
|
@ -77,34 +78,38 @@ struct DispatcherInner<S, B, X, U> {
|
|||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq)]
|
||||
enum PollPayloadStatus {
|
||||
enum ReadPayloadStatus {
|
||||
Done,
|
||||
Updated,
|
||||
Pending,
|
||||
Dropped,
|
||||
}
|
||||
|
||||
impl<S, B, X, U> Dispatcher<S, B, X, U>
|
||||
enum WritePayloadStatus<B> {
|
||||
Next(State<B>),
|
||||
Pause,
|
||||
Continue,
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError,
|
||||
U: Service<Request = (Request, IoState, Codec), Response = ()>,
|
||||
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U::Error: Error + fmt::Display,
|
||||
{
|
||||
/// Construct new `Dispatcher` instance with outgoing messages stream.
|
||||
pub(in crate::http) fn new<T>(
|
||||
pub(in crate::http) fn new(
|
||||
io: T,
|
||||
config: Rc<DispatcherConfig<S, X, U>>,
|
||||
peer_addr: Option<net::SocketAddr>,
|
||||
on_connect_data: Option<Box<dyn DataFactory>>,
|
||||
) -> Self
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
) -> Self {
|
||||
let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled());
|
||||
|
||||
let state = IoState::new();
|
||||
|
@ -115,18 +120,19 @@ where
|
|||
|
||||
// slow-request timer
|
||||
if config.client_timeout != 0 {
|
||||
expire += Duration::from_secs(config.client_timeout);
|
||||
expire += time::Duration::from_secs(config.client_timeout);
|
||||
config.timer_h1.register(expire, expire, &state);
|
||||
}
|
||||
|
||||
// start support io tasks
|
||||
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
|
||||
crate::rt::spawn(WriteTask::new(io, state.clone()));
|
||||
crate::rt::spawn(WriteTask::new(io.clone(), state.clone()));
|
||||
|
||||
Dispatcher {
|
||||
call: CallState::None,
|
||||
st: State::ReadRequest,
|
||||
inner: DispatcherInner {
|
||||
io: Some(io),
|
||||
flags: Flags::empty(),
|
||||
error: None,
|
||||
payload: None,
|
||||
|
@ -142,15 +148,16 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<S, B, X, U> Future for Dispatcher<S, B, X, U>
|
||||
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, IoState, Codec), Response = ()>,
|
||||
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U::Error: Error + fmt::Display + 'static,
|
||||
{
|
||||
type Output = Result<(), DispatchError>;
|
||||
|
@ -180,7 +187,7 @@ where
|
|||
// we might need to read more data into a request payload
|
||||
// (ie service future can wait for payload data)
|
||||
if this.inner.poll_read_payload(cx)
|
||||
!= PollPayloadStatus::Updated
|
||||
!= ReadPayloadStatus::Updated
|
||||
{
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
@ -197,27 +204,16 @@ where
|
|||
b"HTTP/1.1 100 Continue\r\n\r\n",
|
||||
)
|
||||
});
|
||||
Some(if this.inner.flags.contains(Flags::UPGRADE) {
|
||||
// Handle UPGRADE request
|
||||
CallState::Upgrade {
|
||||
fut: this
|
||||
.inner
|
||||
.config
|
||||
.upgrade
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.call((
|
||||
req,
|
||||
this.inner.state.clone(),
|
||||
this.inner.codec.clone(),
|
||||
)),
|
||||
}
|
||||
if this.inner.flags.contains(Flags::UPGRADE) {
|
||||
this.inner.state.dsp_stop_io(cx.waker());
|
||||
*this.st = State::Upgrade(Some(req));
|
||||
return Poll::Pending;
|
||||
} else {
|
||||
CallState::Service {
|
||||
Some(CallState::Service {
|
||||
fut: this.inner.config.service.call(req),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
*this.st = this.inner.handle_error(e, true);
|
||||
None
|
||||
|
@ -273,7 +269,11 @@ where
|
|||
if this.inner.state.is_read_ready() {
|
||||
match this.inner.state.decode_item(&this.inner.codec) {
|
||||
Ok(Some((mut req, pl))) => {
|
||||
log::trace!("http message is received: {:?}", req);
|
||||
log::trace!(
|
||||
"http message is received: {:?} and payload {:?}",
|
||||
req,
|
||||
pl
|
||||
);
|
||||
req.head_mut().peer_addr = this.inner.peer_addr;
|
||||
|
||||
// configure request payload
|
||||
|
@ -313,48 +313,46 @@ where
|
|||
on_connect.set(&mut req.extensions_mut());
|
||||
}
|
||||
|
||||
if req.head().expect() {
|
||||
// call service
|
||||
*this.st = State::Call;
|
||||
this.call.set(if req.head().expect() {
|
||||
// Handle `EXPECT: 100-Continue` header
|
||||
CallState::Expect {
|
||||
this.call.set(CallState::Expect {
|
||||
fut: this.inner.config.expect.call(req),
|
||||
}
|
||||
});
|
||||
} else if upgrade {
|
||||
log::trace!("initate upgrade handling");
|
||||
log::trace!("prep io for upgrade handler");
|
||||
// Handle UPGRADE request
|
||||
CallState::Upgrade {
|
||||
fut: this
|
||||
.inner
|
||||
.config
|
||||
.upgrade
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.call((
|
||||
req,
|
||||
this.inner.state.clone(),
|
||||
this.inner.codec.clone(),
|
||||
)),
|
||||
}
|
||||
this.inner.state.dsp_stop_io(cx.waker());
|
||||
*this.st = State::Upgrade(Some(req));
|
||||
return Poll::Pending;
|
||||
} else {
|
||||
// Handle normal requests
|
||||
CallState::Service {
|
||||
*this.st = State::Call;
|
||||
this.call.set(CallState::Service {
|
||||
fut: this.inner.config.service.call(req),
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// if connection is not keep-alive then disconnect
|
||||
log::trace!("not enough data to decode next frame, register dispatch task");
|
||||
|
||||
// if io error occured or connection is not keep-alive
|
||||
// then disconnect
|
||||
if this.inner.flags.contains(Flags::STARTED)
|
||||
&& !this.inner.flags.contains(Flags::KEEPALIVE)
|
||||
&& (!this.inner.flags.contains(Flags::KEEPALIVE)
|
||||
|| !this.inner.codec.keepalive_enabled()
|
||||
|| this.inner.state.is_io_err())
|
||||
{
|
||||
*this.st = State::Stop;
|
||||
this.inner.state.dsp_mark_stopped();
|
||||
continue;
|
||||
}
|
||||
this.inner.state.dsp_read_more_data(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
Err(err) => {
|
||||
log::trace!("malformed request: {:?}", err);
|
||||
// Malformed requests, respond with 400
|
||||
let (res, body) =
|
||||
Response::BadRequest().finish().into_parts();
|
||||
|
@ -379,31 +377,70 @@ where
|
|||
// consume request's payload
|
||||
State::ReadPayload => loop {
|
||||
match this.inner.poll_read_payload(cx) {
|
||||
PollPayloadStatus::Updated => continue,
|
||||
PollPayloadStatus::Pending => return Poll::Pending,
|
||||
PollPayloadStatus::Done => {
|
||||
ReadPayloadStatus::Updated => continue,
|
||||
ReadPayloadStatus::Pending => return Poll::Pending,
|
||||
ReadPayloadStatus::Done => {
|
||||
*this.st = {
|
||||
this.inner.reset_keepalive();
|
||||
State::ReadRequest
|
||||
}
|
||||
}
|
||||
PollPayloadStatus::Dropped => *this.st = State::Stop,
|
||||
ReadPayloadStatus::Dropped => *this.st = State::Stop,
|
||||
}
|
||||
break;
|
||||
},
|
||||
// send response body
|
||||
State::SendPayload { ref mut body } => {
|
||||
if this.inner.state.is_io_err() {
|
||||
*this.st = State::Stop;
|
||||
} else {
|
||||
this.inner.poll_read_payload(cx);
|
||||
|
||||
match body.poll_next_chunk(cx) {
|
||||
Poll::Ready(item) => {
|
||||
if let Some(st) = this.inner.send_payload(item) {
|
||||
Poll::Ready(item) => match this.inner.send_payload(item) {
|
||||
WritePayloadStatus::Next(st) => {
|
||||
*this.st = st;
|
||||
}
|
||||
WritePayloadStatus::Pause => {
|
||||
this.inner.state.dsp_flush_write_data(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
WritePayloadStatus::Continue => (),
|
||||
},
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
// stop io tasks and call upgrade service
|
||||
State::Upgrade(ref mut req) => {
|
||||
// check if all io tasks have been stopped
|
||||
let io = if Rc::strong_count(this.inner.io.as_ref().unwrap()) == 1 {
|
||||
if let Ok(io) = Rc::try_unwrap(this.inner.io.take().unwrap()) {
|
||||
io.into_inner()
|
||||
} else {
|
||||
return Poll::Ready(Err(DispatchError::InternalError));
|
||||
}
|
||||
} else {
|
||||
// wait next task stop
|
||||
this.inner.state.dsp_register_task(cx.waker());
|
||||
return Poll::Pending;
|
||||
};
|
||||
log::trace!("initate upgrade handling");
|
||||
|
||||
let req = req.take().unwrap();
|
||||
*this.st = State::Call;
|
||||
this.inner.state.reset_io_stop();
|
||||
|
||||
// Handle UPGRADE request
|
||||
this.call.set(CallState::Upgrade {
|
||||
fut: this.inner.config.upgrade.as_ref().unwrap().call((
|
||||
req,
|
||||
io,
|
||||
this.inner.state.clone(),
|
||||
this.inner.codec.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
// prepare to shutdown
|
||||
State::Stop => {
|
||||
this.inner.state.shutdown_io();
|
||||
|
@ -426,7 +463,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<S, B, X, U> DispatcherInner<S, B, X, U>
|
||||
impl<T, S, B, X, U> DispatcherInner<T, S, B, X, U>
|
||||
where
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
|
@ -442,8 +479,8 @@ where
|
|||
fn reset_keepalive(&mut self) {
|
||||
// re-register keep-alive
|
||||
if self.flags.contains(Flags::KEEPALIVE) {
|
||||
let expire =
|
||||
self.config.timer_h1.now() + Duration::from_secs(self.config.keep_alive);
|
||||
let expire = self.config.timer_h1.now()
|
||||
+ time::Duration::from_secs(self.config.keep_alive);
|
||||
self.config
|
||||
.timer_h1
|
||||
.register(expire, self.expire, &self.state);
|
||||
|
@ -512,18 +549,25 @@ where
|
|||
fn send_payload(
|
||||
&mut self,
|
||||
item: Option<Result<Bytes, Box<dyn Error>>>,
|
||||
) -> Option<State<B>> {
|
||||
) -> WritePayloadStatus<B> {
|
||||
match item {
|
||||
Some(Ok(item)) => {
|
||||
trace!("Got response chunk: {:?}", item.len());
|
||||
if let Err(err) = self
|
||||
match self
|
||||
.state
|
||||
.write_item(Message::Chunk(Some(item)), &self.codec)
|
||||
{
|
||||
Err(err) => {
|
||||
self.error = Some(DispatchError::Encode(err));
|
||||
Some(State::Stop)
|
||||
WritePayloadStatus::Next(State::Stop)
|
||||
}
|
||||
Ok(has_space) => {
|
||||
if has_space {
|
||||
WritePayloadStatus::Continue
|
||||
} else {
|
||||
None
|
||||
WritePayloadStatus::Pause
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
|
@ -532,24 +576,24 @@ where
|
|||
self.state.write_item(Message::Chunk(None), &self.codec)
|
||||
{
|
||||
self.error = Some(DispatchError::Encode(err));
|
||||
Some(State::Stop)
|
||||
WritePayloadStatus::Next(State::Stop)
|
||||
} else if self.payload.is_some() {
|
||||
Some(State::ReadPayload)
|
||||
WritePayloadStatus::Next(State::ReadPayload)
|
||||
} else {
|
||||
self.reset_keepalive();
|
||||
Some(State::ReadRequest)
|
||||
WritePayloadStatus::Next(State::ReadRequest)
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
trace!("Error during response body poll: {:?}", e);
|
||||
self.error = Some(DispatchError::ResponsePayload(e));
|
||||
Some(State::Stop)
|
||||
WritePayloadStatus::Next(State::Stop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process request's payload
|
||||
fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> PollPayloadStatus {
|
||||
fn poll_read_payload(&mut self, cx: &mut Context<'_>) -> ReadPayloadStatus {
|
||||
// check if payload data is required
|
||||
if let Some(ref mut payload) = self.payload {
|
||||
match payload.1.poll_data_required(cx) {
|
||||
|
@ -557,7 +601,7 @@ where
|
|||
// read request payload
|
||||
let mut updated = false;
|
||||
loop {
|
||||
let item = self.state.with_read_buf(|buf| payload.0.decode(buf));
|
||||
let item = self.state.decode_item(&payload.0);
|
||||
match item {
|
||||
Ok(Some(PayloadItem::Chunk(chunk))) => {
|
||||
updated = true;
|
||||
|
@ -567,40 +611,434 @@ where
|
|||
payload.1.feed_eof();
|
||||
self.payload = None;
|
||||
if !updated {
|
||||
return PollPayloadStatus::Done;
|
||||
return ReadPayloadStatus::Done;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Ok(None) => {
|
||||
if self.state.is_io_err() {
|
||||
payload.1.set_error(PayloadError::EncodingCorrupted);
|
||||
self.payload = None;
|
||||
self.error = Some(ParseError::Incomplete.into());
|
||||
return ReadPayloadStatus::Dropped;
|
||||
} else {
|
||||
self.state.dsp_read_more_data(cx.waker());
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
payload.1.set_error(PayloadError::EncodingCorrupted);
|
||||
self.payload = None;
|
||||
self.error = Some(DispatchError::Parse(e));
|
||||
return PollPayloadStatus::Dropped;
|
||||
return ReadPayloadStatus::Dropped;
|
||||
}
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
PollPayloadStatus::Updated
|
||||
ReadPayloadStatus::Updated
|
||||
} else {
|
||||
PollPayloadStatus::Pending
|
||||
ReadPayloadStatus::Pending
|
||||
}
|
||||
}
|
||||
PayloadStatus::Pause => PollPayloadStatus::Pending,
|
||||
PayloadStatus::Pause => ReadPayloadStatus::Pending,
|
||||
PayloadStatus::Dropped => {
|
||||
// service call is not interested in payload
|
||||
// wait until future completes and then close
|
||||
// connection
|
||||
self.payload = None;
|
||||
self.error = Some(DispatchError::PayloadIsNotConsumed);
|
||||
PollPayloadStatus::Dropped
|
||||
ReadPayloadStatus::Dropped
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PollPayloadStatus::Done
|
||||
ReadPayloadStatus::Done
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::{io, sync::Arc};
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::future::{lazy, ok, FutureExt};
|
||||
use futures::StreamExt;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
use crate::codec::Decoder;
|
||||
use crate::http::config::{DispatcherConfig, ServiceConfig};
|
||||
use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler};
|
||||
use crate::http::{body, Request, ResponseHead, StatusCode};
|
||||
use crate::rt::time::delay_for;
|
||||
use crate::service::IntoService;
|
||||
use crate::testing::Io;
|
||||
|
||||
const BUFFER_SIZE: usize = 32_768;
|
||||
|
||||
/// Create http/1 dispatcher.
|
||||
pub(crate) fn h1<F, S, B>(
|
||||
stream: Io,
|
||||
service: F,
|
||||
) -> Dispatcher<Io, S, B, ExpectHandler, UpgradeHandler<Io>>
|
||||
where
|
||||
F: IntoService<S>,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody,
|
||||
{
|
||||
Dispatcher::new(
|
||||
stream,
|
||||
Rc::new(DispatcherConfig::new(
|
||||
ServiceConfig::default(),
|
||||
service.into_service(),
|
||||
ExpectHandler,
|
||||
None,
|
||||
)),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn spawn_h1<F, S, B>(stream: Io, service: F)
|
||||
where
|
||||
F: IntoService<S>,
|
||||
S: Service<Request = Request> + 'static,
|
||||
S::Error: ResponseError,
|
||||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody + 'static,
|
||||
{
|
||||
crate::rt::spawn(
|
||||
Dispatcher::<Io, S, B, ExpectHandler, UpgradeHandler<Io>>::new(
|
||||
stream,
|
||||
Rc::new(DispatcherConfig::new(
|
||||
ServiceConfig::default(),
|
||||
service.into_service(),
|
||||
ExpectHandler,
|
||||
None,
|
||||
)),
|
||||
None,
|
||||
None,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead {
|
||||
decoder.decode(buf).unwrap().unwrap()
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_req_parse_err() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
|
||||
assert!(!h1.inner.state.is_open());
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
|
||||
client
|
||||
.local_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n"));
|
||||
|
||||
client.close().await;
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
|
||||
// assert!(h1.inner.flags.contains(Flags::SHUTDOWN_IO));
|
||||
assert!(h1.inner.state.is_io_err());
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_pipeline() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(4096);
|
||||
let mut decoder = ClientCodec::default();
|
||||
spawn_h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(!client.is_server_dropped());
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(decoder.decode(&mut buf).unwrap().is_none());
|
||||
assert!(!client.is_server_dropped());
|
||||
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_pipeline_with_payload() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(4096);
|
||||
let mut decoder = ClientCodec::default();
|
||||
spawn_h1(server, |mut req: Request| async move {
|
||||
let mut p = req.take_payload();
|
||||
while let Some(_) = p.next().await {}
|
||||
Ok::<_, io::Error>(Response::Ok().finish())
|
||||
});
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\ncontent-length: 5\r\n\r\n");
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
client.write("xxxxx");
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(!client.is_server_dropped());
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(decoder.decode(&mut buf).unwrap().is_none());
|
||||
assert!(!client.is_server_dropped());
|
||||
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_pipeline_with_delay() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(4096);
|
||||
let mut decoder = ClientCodec::default();
|
||||
spawn_h1(server, |_| async {
|
||||
delay_for(time::Duration::from_millis(100)).await;
|
||||
Ok::<_, io::Error>(Response::Ok().finish())
|
||||
});
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(!client.is_server_dropped());
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(decoder.decode(&mut buf).unwrap().is_none());
|
||||
assert!(!client.is_server_dropped());
|
||||
|
||||
buf.extend(client.read().await.unwrap());
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(decoder.decode(&mut buf).unwrap().is_none());
|
||||
assert!(!client.is_server_dropped());
|
||||
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
/// if socket is disconnected, h1 dispatcher does not process any data
|
||||
// /// h1 dispatcher still processes all incoming requests
|
||||
// /// but it does not write any data to socket
|
||||
async fn test_write_disconnected() {
|
||||
let num = Arc::new(AtomicUsize::new(0));
|
||||
let num2 = num.clone();
|
||||
|
||||
let (client, server) = Io::create();
|
||||
spawn_h1(server, move |_| {
|
||||
num2.fetch_add(1, Ordering::Relaxed);
|
||||
ok::<_, io::Error>(Response::Ok().finish())
|
||||
});
|
||||
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
assert!(client.read_any().is_empty());
|
||||
|
||||
// only first request get handled
|
||||
assert_eq!(num.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_read_large_message() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(4096);
|
||||
|
||||
let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
|
||||
let mut decoder = ClientCodec::default();
|
||||
|
||||
let data = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(70_000)
|
||||
.map(char::from)
|
||||
.collect::<String>();
|
||||
client.write("GET /test HTTP/1.1\r\nContent-Length: ");
|
||||
client.write(data);
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
|
||||
assert!(!h1.inner.state.is_open());
|
||||
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert_eq!(load(&mut decoder, &mut buf).status, StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_read_backpressure() {
|
||||
let mark = Arc::new(AtomicBool::new(false));
|
||||
let mark2 = mark.clone();
|
||||
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(4096);
|
||||
spawn_h1(server, move |mut req: Request| {
|
||||
let m = mark2.clone();
|
||||
async move {
|
||||
// read one chunk
|
||||
let mut pl = req.take_payload();
|
||||
let _ = pl.next().await.unwrap().unwrap();
|
||||
m.store(true, Ordering::Relaxed);
|
||||
// sleep
|
||||
delay_for(time::Duration::from_secs(999_999)).await;
|
||||
Ok::<_, io::Error>(Response::Ok().finish())
|
||||
}
|
||||
});
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\nContent-Length: 1048576\r\n\r\n");
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
|
||||
// buf must be consumed
|
||||
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
|
||||
|
||||
// io should be drained only by no more than MAX_BUFFER_SIZE
|
||||
let random_bytes: Vec<u8> =
|
||||
(0..1_048_576).map(|_| rand::random::<u8>()).collect();
|
||||
client.write(random_bytes);
|
||||
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
assert!(client.remote_buffer(|buf| buf.len()) > 1_048_576 - BUFFER_SIZE * 3);
|
||||
assert!(mark.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_write_backpressure() {
|
||||
std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace");
|
||||
env_logger::init();
|
||||
|
||||
let num = Arc::new(AtomicUsize::new(0));
|
||||
let num2 = num.clone();
|
||||
|
||||
struct Stream(Arc<AtomicUsize>);
|
||||
|
||||
impl body::MessageBody for Stream {
|
||||
fn size(&self) -> body::BodySize {
|
||||
body::BodySize::Stream
|
||||
}
|
||||
fn poll_next_chunk(
|
||||
&mut self,
|
||||
_: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Bytes, Box<dyn std::error::Error>>>> {
|
||||
let data = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(65_536)
|
||||
.map(char::from)
|
||||
.collect::<String>();
|
||||
self.0.fetch_add(data.len(), Ordering::Relaxed);
|
||||
|
||||
Poll::Ready(Some(Ok(Bytes::from(data))))
|
||||
}
|
||||
}
|
||||
|
||||
let (client, server) = Io::create();
|
||||
let mut h1 = h1(server, move |_| {
|
||||
let n = num2.clone();
|
||||
async move { Ok::<_, io::Error>(Response::Ok().message_body(Stream(n.clone()))) }
|
||||
.boxed_local()
|
||||
});
|
||||
let state = h1.inner.state.clone();
|
||||
|
||||
// do not allow to write to socket
|
||||
client.remote_buffer_cap(0);
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
|
||||
|
||||
// buf must be consumed
|
||||
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
|
||||
|
||||
// amount of generated data
|
||||
assert_eq!(num.load(Ordering::Relaxed), 65_536);
|
||||
|
||||
// response message + chunking encoding
|
||||
assert_eq!(state.with_write_buf(|buf| buf.len()), 65629);
|
||||
|
||||
client.remote_buffer_cap(65536);
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
assert_eq!(state.with_write_buf(|buf| buf.len()), 93);
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
|
||||
assert_eq!(num.load(Ordering::Relaxed), 65_536 * 2);
|
||||
}
|
||||
|
||||
#[ntex_rt::test]
|
||||
async fn test_disconnect_during_response_body_pending() {
|
||||
struct Stream(bool);
|
||||
|
||||
impl body::MessageBody for Stream {
|
||||
fn size(&self) -> body::BodySize {
|
||||
body::BodySize::Sized(2048)
|
||||
}
|
||||
fn poll_next_chunk(
|
||||
&mut self,
|
||||
_: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Bytes, Box<dyn std::error::Error>>>> {
|
||||
if self.0 {
|
||||
Poll::Pending
|
||||
} else {
|
||||
self.0 = true;
|
||||
let data = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(1024)
|
||||
.map(char::from)
|
||||
.collect::<String>();
|
||||
Poll::Ready(Some(Ok(Bytes::from(data))))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(4096);
|
||||
let mut h1 = h1(server, |_| {
|
||||
ok::<_, io::Error>(Response::Ok().message_body(Stream(false)))
|
||||
});
|
||||
|
||||
client.write("GET /test HTTP/1.1\r\n\r\n");
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
|
||||
|
||||
// http message must be consumed
|
||||
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
|
||||
|
||||
let mut decoder = ClientCodec::default();
|
||||
let mut buf = client.read().await.unwrap();
|
||||
assert!(load(&mut decoder, &mut buf).status.is_success());
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
|
||||
|
||||
client.close().await;
|
||||
delay_for(time::Duration::from_millis(50)).await;
|
||||
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -68,7 +68,11 @@ where
|
|||
X::Error: ResponseError + 'static,
|
||||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
U: ServiceFactory<Config = (), Request = (Request, IoState, Codec), Response = ()>,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, TcpStream, IoState, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
|
@ -112,7 +116,7 @@ mod openssl {
|
|||
X::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, IoState, Codec),
|
||||
Request = (Request, SslStream<TcpStream>, IoState, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
|
@ -166,7 +170,7 @@ mod rustls {
|
|||
X::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, IoState, Codec),
|
||||
Request = (Request, TlsStream<TcpStream>, IoState, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
|
@ -228,7 +232,7 @@ where
|
|||
|
||||
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, S, B, X, U1>
|
||||
where
|
||||
U1: ServiceFactory<Request = (Request, IoState, Codec), Response = ()>,
|
||||
U1: ServiceFactory<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U1::Error: fmt::Display + Error + 'static,
|
||||
U1::InitError: fmt::Debug,
|
||||
U1::Future: 'static,
|
||||
|
@ -267,7 +271,11 @@ where
|
|||
X::Error: ResponseError + 'static,
|
||||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
U: ServiceFactory<Config = (), Request = (Request, IoState, Codec), Response = ()>,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, T, IoState, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
|
@ -331,13 +339,13 @@ where
|
|||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, IoState, Codec), Response = ()>,
|
||||
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
{
|
||||
type Request = (T, Option<net::SocketAddr>);
|
||||
type Response = ();
|
||||
type Error = DispatchError;
|
||||
type Future = Dispatcher<S, B, X, U>;
|
||||
type Future = Dispatcher<T, S, B, X, U>;
|
||||
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let cfg = self.config.as_ref();
|
||||
|
|
|
@ -11,7 +11,7 @@ pub struct UpgradeHandler<T>(PhantomData<T>);
|
|||
|
||||
impl<T> ServiceFactory for UpgradeHandler<T> {
|
||||
type Config = ();
|
||||
type Request = (Request, State, Codec);
|
||||
type Request = (Request, T, State, Codec);
|
||||
type Response = ();
|
||||
type Error = io::Error;
|
||||
type Service = UpgradeHandler<T>;
|
||||
|
@ -25,7 +25,7 @@ impl<T> ServiceFactory for UpgradeHandler<T> {
|
|||
}
|
||||
|
||||
impl<T> Service for UpgradeHandler<T> {
|
||||
type Request = (Request, State, Codec);
|
||||
type Request = (Request, T, State, Codec);
|
||||
type Response = ();
|
||||
type Error = io::Error;
|
||||
type Future = Ready<Result<Self::Response, Self::Error>>;
|
||||
|
|
|
@ -126,7 +126,7 @@ where
|
|||
where
|
||||
U1: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, State, h1::Codec),
|
||||
Request = (Request, T, State, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U1::Error: fmt::Display + error::Error + 'static,
|
||||
|
@ -167,7 +167,11 @@ where
|
|||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<Config = (), Request = (Request, State, h1::Codec), Response = ()>,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, TcpStream, State, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
|
@ -213,7 +217,7 @@ mod openssl {
|
|||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, State, h1::Codec),
|
||||
Request = (Request, SslStream<TcpStream>, State, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
|
@ -278,7 +282,7 @@ mod rustls {
|
|||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, State, h1::Codec),
|
||||
Request = (Request, TlsStream<TcpStream>, State, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
|
@ -342,7 +346,11 @@ where
|
|||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<Config = (), Request = (Request, State, h1::Codec), Response = ()>,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, T, State, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
|
@ -410,7 +418,7 @@ where
|
|||
B: MessageBody + 'static,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
{
|
||||
type Request = (T, Protocol, Option<net::SocketAddr>);
|
||||
|
@ -513,6 +521,7 @@ pin_project_lite::pin_project! {
|
|||
T: AsyncRead,
|
||||
T: AsyncWrite,
|
||||
T: Unpin,
|
||||
T: 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError,
|
||||
S::Error: 'static,
|
||||
|
@ -523,7 +532,7 @@ pin_project_lite::pin_project! {
|
|||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError,
|
||||
X::Error: 'static,
|
||||
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display,
|
||||
U::Error: error::Error,
|
||||
U::Error: 'static,
|
||||
|
@ -543,16 +552,17 @@ pin_project_lite::pin_project! {
|
|||
T: AsyncRead,
|
||||
T: AsyncWrite,
|
||||
T: Unpin,
|
||||
T: 'static,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError,
|
||||
X::Error: 'static,
|
||||
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display,
|
||||
U::Error: error::Error,
|
||||
U::Error: 'static,
|
||||
{
|
||||
H1 { #[pin] fut: h1::Dispatcher<S, B, X, U> },
|
||||
H1 { #[pin] fut: h1::Dispatcher<T, S, B, X, U> },
|
||||
H2 { fut: Dispatcher<T, S, B, X, U> },
|
||||
H2Handshake { data:
|
||||
Option<(
|
||||
|
@ -567,7 +577,7 @@ pin_project_lite::pin_project! {
|
|||
|
||||
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Future: 'static,
|
||||
|
@ -575,7 +585,7 @@ where
|
|||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
{
|
||||
type Output = Result<(), DispatchError>;
|
||||
|
|
|
@ -8,12 +8,12 @@ use ntex::framed::{DispatchItem, Dispatcher, State};
|
|||
use ntex::http::test::server as test_server;
|
||||
use ntex::http::ws::handshake_response;
|
||||
use ntex::http::{body::BodySize, h1, HttpService, Request, Response};
|
||||
use ntex::rt::net::TcpStream;
|
||||
use ntex::ws;
|
||||
|
||||
async fn ws_service(
|
||||
msg: DispatchItem<ws::Codec>,
|
||||
) -> Result<Option<ws::Message>, io::Error> {
|
||||
println!("TEST: {:?}", msg);
|
||||
let msg = match msg {
|
||||
DispatchItem::Item(msg) => match msg {
|
||||
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
|
||||
|
@ -31,12 +31,10 @@ async fn ws_service(
|
|||
|
||||
#[ntex::test]
|
||||
async fn test_simple() {
|
||||
std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace");
|
||||
env_logger::init();
|
||||
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.upgrade(|(req, state, mut codec): (Request, State, h1::Codec)| {
|
||||
.upgrade(
|
||||
|(req, io, state, mut codec): (Request, TcpStream, State, h1::Codec)| {
|
||||
async move {
|
||||
let res = handshake_response(req.head()).finish();
|
||||
|
||||
|
@ -49,7 +47,8 @@ async fn test_simple() {
|
|||
.unwrap();
|
||||
|
||||
// start websocket service
|
||||
Dispatcher::from_state(
|
||||
Dispatcher::new(
|
||||
io,
|
||||
ws::Codec::default(),
|
||||
state,
|
||||
ws_service,
|
||||
|
@ -57,7 +56,8 @@ async fn test_simple() {
|
|||
)
|
||||
.await
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
.finish(|_| ok::<_, io::Error>(Response::NotFound()))
|
||||
.tcp()
|
||||
});
|
||||
|
|
|
@ -85,6 +85,9 @@ async fn test_expect_continue() {
|
|||
|
||||
#[ntex::test]
|
||||
async fn test_expect_continue_h1() {
|
||||
std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace");
|
||||
env_logger::init();
|
||||
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.expect(fn_service(|req: Request| {
|
||||
|
@ -115,7 +118,9 @@ async fn test_expect_continue_h1() {
|
|||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
let _ = stream.write_all(b"GET /test?yes= HTTP/1.1\r\nexpect: 100-continue\r\n\r\n");
|
||||
let mut data = String::new();
|
||||
println!("1-------------------");
|
||||
let _ = stream.read_to_string(&mut data);
|
||||
println!("2-------------------");
|
||||
assert!(data.starts_with("HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 200 OK\r\n"));
|
||||
}
|
||||
|
||||
|
|
|
@ -1,20 +1,21 @@
|
|||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
use std::{cell::Cell, io, pin::Pin};
|
||||
use std::{cell::Cell, io, marker::PhantomData, pin::Pin};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{future, Future, SinkExt, StreamExt};
|
||||
|
||||
use ntex::codec::{AsyncRead, AsyncWrite};
|
||||
use ntex::framed::{DispatchItem, Dispatcher, State, Timer};
|
||||
use ntex::http::{body, h1, test, ws::handshake, HttpService, Request, Response};
|
||||
use ntex::service::{fn_factory, Service};
|
||||
use ntex::ws;
|
||||
|
||||
struct WsService(Arc<Mutex<Cell<bool>>>);
|
||||
struct WsService<T>(Arc<Mutex<Cell<bool>>>, PhantomData<T>);
|
||||
|
||||
impl WsService {
|
||||
impl<T> WsService<T> {
|
||||
fn new() -> Self {
|
||||
WsService(Arc::new(Mutex::new(Cell::new(false))))
|
||||
WsService(Arc::new(Mutex::new(Cell::new(false))), PhantomData)
|
||||
}
|
||||
|
||||
fn set_polled(&self) {
|
||||
|
@ -26,14 +27,17 @@ impl WsService {
|
|||
}
|
||||
}
|
||||
|
||||
impl Clone for WsService {
|
||||
impl<T> Clone for WsService<T> {
|
||||
fn clone(&self) -> Self {
|
||||
WsService(self.0.clone())
|
||||
WsService(self.0.clone(), PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
impl Service for WsService {
|
||||
type Request = (Request, State, h1::Codec);
|
||||
impl<T> Service for WsService<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
type Request = (Request, T, State, h1::Codec);
|
||||
type Response = ();
|
||||
type Error = io::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<(), io::Error>>>>;
|
||||
|
@ -43,7 +47,7 @@ impl Service for WsService {
|
|||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&self, (req, state, mut codec): Self::Request) -> Self::Future {
|
||||
fn call(&self, (req, io, state, mut codec): Self::Request) -> Self::Future {
|
||||
let fut = async move {
|
||||
let res = handshake(req.head()).unwrap().message_body(());
|
||||
|
||||
|
@ -51,7 +55,7 @@ impl Service for WsService {
|
|||
.write_item((res, body::BodySize::None).into(), &mut codec)
|
||||
.unwrap();
|
||||
|
||||
Dispatcher::from_state(ws::Codec::new(), state, service, Timer::default())
|
||||
Dispatcher::new(io, ws::Codec::new(), state, service, Timer::default())
|
||||
.await
|
||||
.map_err(|_| panic!())
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue