Refactor h1 dispatcher (#38)

* refactor h1 dispatcher

* Rename FrameReadTask/FramedWriteTask to ReadTask/WriteTask

* Make Encoder and Decoder methods immutable
This commit is contained in:
Nikolay Kim 2021-01-23 19:44:56 +06:00 committed by GitHub
parent 20f38402ab
commit 6c250c9a4d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 1585 additions and 2292 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.3.0] - 2021-01-23
* Make Encoder and Decoder methods immutable
## [0.2.2] - 2021-01-21
* Flush underlying io stream

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-codec"
version = "0.2.2"
version = "0.3.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -18,12 +18,12 @@ path = "src/lib.rs"
[dependencies]
bitflags = "1.2.1"
bytes = "0.5.6"
either = "1.5.3"
either = "1.6.1"
futures-core = "0.3.12"
futures-sink = "0.3.12"
log = "0.4"
tokio = { version = "0.2.6", default-features=false }
[dev-dependencies]
ntex = "0.2.0-b.2"
ntex = "0.2.0-b.3"
futures = "0.3.12"

View file

@ -14,7 +14,7 @@ impl Encoder for BytesCodec {
type Error = io::Error;
#[inline]
fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.extend_from_slice(item.bytes());
Ok(())
}
@ -24,7 +24,7 @@ impl Decoder for BytesCodec {
type Item = BytesMut;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.is_empty() {
Ok(None)
} else {

View file

@ -1,4 +1,5 @@
use bytes::BytesMut;
use std::rc::Rc;
/// Decoding of frames via buffers.
pub trait Decoder {
@ -13,7 +14,7 @@ pub trait Decoder {
type Error: std::fmt::Debug;
/// Attempts to decode a frame from the provided buffer of bytes.
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>;
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error>;
/// A default method available to be called when there are no more bytes
/// available to be read from the underlying I/O.
@ -22,13 +23,26 @@ pub trait Decoder {
/// `Ok(None)` is returned while there is unconsumed data in `buf`.
/// Typically this doesn't need to be implemented unless the framing
/// protocol differs near the end of the stream.
fn decode_eof(
&mut self,
buf: &mut BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
fn decode_eof(&self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.decode(buf)? {
Some(frame) => Ok(Some(frame)),
None => Ok(None),
}
}
}
impl<T> Decoder for Rc<T>
where
T: Decoder,
{
type Item = T::Item;
type Error = T::Error;
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
(**self).decode(src)
}
fn decode_eof(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
(**self).decode_eof(src)
}
}

View file

@ -1,4 +1,5 @@
use bytes::BytesMut;
use std::rc::Rc;
/// Trait of helper objects to write out messages as bytes.
pub trait Encoder {
@ -9,9 +10,17 @@ pub trait Encoder {
type Error: std::fmt::Debug;
/// Encodes a frame into the buffer provided.
fn encode(
&mut self,
item: Self::Item,
dst: &mut BytesMut,
) -> Result<(), Self::Error>;
fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error>;
}
impl<T> Encoder for Rc<T>
where
T: Encoder,
{
type Item = T::Item;
type Error = T::Error;
fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
(**self).encode(item, dst)
}
}

View file

@ -1,5 +1,15 @@
# Changes
## [0.2.0-b.4] - 2021-01-xx
* http: Refactor h1 dispatcher
* http: Remove generic type from `Request`
* http: Remove generic type from `Payload`
* Rename FrameReadTask/FramedWriteTask to ReadTask/WriteTask
## [0.2.0-b.3] - 2021-01-21
* Allow to use framed write task for io flushing

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.2.0-b.3"
version = "0.2.0-b.4"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -36,7 +36,7 @@ compress = ["flate2", "brotli2"]
cookie = ["coo-kie", "coo-kie/percent-encode"]
[dependencies]
ntex-codec = "0.2.2"
ntex-codec = "0.3.0-b.1"
ntex-rt = "0.1.1"
ntex-rt-macros = "0.1"
ntex-router = "0.3.8"
@ -49,7 +49,7 @@ bitflags = "1.2.1"
bytes = "0.5.6"
bytestring = "0.1.5"
derive_more = "0.99.11"
either = "1.5.3"
either = "1.6.1"
encoding_rs = "0.8.26"
futures = "0.3.12"
ahash = "0.6.3"
@ -104,3 +104,6 @@ serde_derive = "1.0"
open-ssl = { version="0.10", package = "openssl" }
rust-tls = { version = "0.19.0", package="rustls", features = ["dangerous_configuration"] }
webpki = "0.21.2"
[patch.crates-io]
ntex = { path = "../ntex-codec" }

View file

@ -9,7 +9,7 @@ use either::Either;
use futures::FutureExt;
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
use crate::framed::{DispatcherItem, FramedReadTask, FramedWriteTask, State, Timer};
use crate::framed::{DispatchItem, ReadTask, State, Timer, WriteTask};
use crate::service::{IntoService, Service};
type Response<U> = <U as Encoder>::Item;
@ -19,7 +19,7 @@ pin_project_lite::pin_project! {
/// and pass then to the service.
pub struct Dispatcher<S, U>
where
S: Service<Request = DispatcherItem<U>, Response = Option<Response<U>>>,
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Encoder,
@ -27,12 +27,7 @@ pin_project_lite::pin_project! {
<U as Encoder>::Item: 'static,
{
service: S,
state: State<U>,
inner: Rc<DispatcherInner<S, U>>,
st: DispatcherState,
timer: Timer<U>,
updated: Instant,
keepalive_timeout: u16,
inner: DispatcherInner<S, U>,
#[pin]
response: Option<S::Future>,
}
@ -40,12 +35,23 @@ pin_project_lite::pin_project! {
struct DispatcherInner<S, U>
where
S: Service<Request = DispatcherItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
{
st: DispatcherState,
state: State,
timer: Timer,
updated: Instant,
keepalive_timeout: u16,
shared: Rc<DispatcherShared<S, U>>,
}
struct DispatcherShared<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Encoder + Decoder,
{
codec: U,
error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
inflight: Cell<usize>,
}
@ -73,13 +79,13 @@ impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
}
impl<E1, E2: fmt::Debug> DispatcherError<E1, E2> {
fn convert<U>(self) -> Option<DispatcherItem<U>>
fn convert<U>(self) -> Option<DispatchItem<U>>
where
U: Encoder<Error = E2> + Decoder,
{
match self {
DispatcherError::KeepAlive => Some(DispatcherItem::KeepAliveTimeout),
DispatcherError::Encoder(err) => Some(DispatcherItem::EncoderError(err)),
DispatcherError::KeepAlive => Some(DispatchItem::KeepAliveTimeout),
DispatcherError::Encoder(err) => Some(DispatchItem::EncoderError(err)),
DispatcherError::Service(_) => None,
}
}
@ -87,46 +93,59 @@ impl<E1, E2: fmt::Debug> DispatcherError<E1, E2> {
impl<S, U> Dispatcher<S, U>
where
S: Service<Request = DispatcherItem<U>, Response = Option<Response<U>>> + 'static,
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + 'static,
<U as Encoder>::Item: 'static,
{
/// Construct new `Dispatcher` instance.
pub fn new<T, F: IntoService<S>>(
io: T,
state: State<U>,
codec: U,
state: State,
service: F,
timer: Timer<U>,
timer: Timer,
) -> Self
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
let io = Rc::new(RefCell::new(io));
// start support tasks
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(WriteTask::new(io, state.clone()));
Self::from_state(codec, state, service, timer)
}
/// Construct new `Dispatcher` instance.
pub fn from_state<F: IntoService<S>>(
codec: U,
state: State,
service: F,
timer: Timer,
) -> Self {
let updated = timer.now();
let keepalive_timeout: u16 = 30;
let io = Rc::new(RefCell::new(io));
// register keepalive timer
let expire = updated + Duration::from_secs(keepalive_timeout as u64);
timer.register(expire, expire, &state);
// start support tasks
crate::rt::spawn(FramedReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(FramedWriteTask::new(io, state.clone()));
let inner = Rc::new(DispatcherInner {
error: Cell::new(None),
inflight: Cell::new(0),
});
Dispatcher {
st: DispatcherState::Processing,
service: service.into_service(),
response: None,
state,
inner,
timer,
updated,
keepalive_timeout,
inner: DispatcherInner {
state,
timer,
updated,
keepalive_timeout,
st: DispatcherState::Processing,
shared: Rc::new(DispatcherShared {
codec,
error: Cell::new(None),
inflight: Cell::new(0),
}),
},
}
}
@ -137,15 +156,15 @@ where
/// By default keep-alive timeout is set to 30 seconds.
pub fn keepalive_timeout(mut self, timeout: u16) -> Self {
// register keepalive timer
let prev = self.updated + Duration::from_secs(self.keepalive_timeout as u64);
let prev = self.inner.updated
+ Duration::from_secs(self.inner.keepalive_timeout as u64);
if timeout == 0 {
self.timer.unregister(prev, &self.state);
self.inner.timer.unregister(prev, &self.inner.state);
} else {
let expire = self.updated + Duration::from_secs(timeout as u64);
self.timer.register(expire, prev, &self.state);
let expire = self.inner.updated + Duration::from_secs(timeout as u64);
self.inner.timer.register(expire, prev, &self.inner.state);
}
self.keepalive_timeout = timeout;
self.inner.keepalive_timeout = timeout;
self
}
@ -159,14 +178,14 @@ where
///
/// By default disconnect timeout is set to 1 seconds.
pub fn disconnect_timeout(self, val: u16) -> Self {
self.state.set_disconnect_timeout(val);
self.inner.state.set_disconnect_timeout(val);
self
}
}
impl<S, U> DispatcherInner<S, U>
impl<S, U> DispatcherShared<S, U>
where
S: Service<Request = DispatcherItem<U>, Response = Option<Response<U>>>,
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Encoder + Decoder,
@ -175,11 +194,11 @@ where
fn handle_result(
&self,
item: Result<S::Response, S::Error>,
state: &State<U>,
state: &State,
wake: bool,
) {
self.inflight.set(self.inflight.get() - 1);
if let Err(err) = state.write_result(item) {
if let Err(err) = state.write_result(item, &self.codec) {
self.error.set(Some(err.into()));
}
@ -189,9 +208,60 @@ where
}
}
impl<S, U> DispatcherInner<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Decoder + Encoder,
{
fn take_error(&self) -> Option<DispatchItem<U>> {
// check for errors
self.shared
.error
.take()
.and_then(|err| err.convert())
.or_else(|| self.state.take_io_error().map(DispatchItem::IoError))
}
/// check keepalive timeout
fn check_keepalive(&self) {
if self.state.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
} else {
self.shared.error.set(Some(DispatcherError::KeepAlive));
}
self.state.dsp_mark_stopped();
}
}
/// update keep-alive timer
fn update_keepalive(&mut self) {
if self.keepalive_timeout != 0 {
let updated = self.timer.now();
if updated != self.updated {
let ka = Duration::from_secs(self.keepalive_timeout as u64);
self.timer
.register(updated + ka, self.updated + ka, &self.state);
self.updated = updated;
}
}
}
/// unregister keep-alive timer
fn unregister_keepalive(&self) {
if self.keepalive_timeout != 0 {
self.timer.unregister(
self.updated + Duration::from_secs(self.keepalive_timeout as u64),
&self.state,
);
}
}
}
impl<S, U> Future for Dispatcher<S, U>
where
S: Service<Request = DispatcherItem<U>, Response = Option<Response<U>>> + 'static,
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + 'static,
<U as Encoder>::Item: 'static,
{
@ -200,125 +270,70 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
// log::trace!("IO-DISP poll :{:?}:", this.st);
// handle service response future
if let Some(fut) = this.response.as_mut().as_pin_mut() {
match fut.poll(cx) {
Poll::Pending => (),
Poll::Ready(item) => {
this.inner.handle_result(item, this.state, false);
this.inner
.shared
.handle_result(item, &this.inner.state, false);
this.response.set(None);
}
}
}
match this.st {
match this.inner.st {
DispatcherState::Processing => {
loop {
// log::trace!("IO-DISP state :{:?}:", this.state.get_flags());
match this.service.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
let mut retry = false;
// service is ready, wake io read task
this.state.dsp_restart_read_task();
this.inner.state.dsp_restart_read_task();
let item = if this.state.is_dsp_stopped() {
// check keepalive timeout
this.inner.check_keepalive();
let item = if this.inner.state.is_dsp_stopped() {
log::trace!("dispatcher is instructed to stop");
// check keepalive timeout
if this.state.is_keepalive_err() {
if let Some(err) = this.inner.error.take() {
this.inner.error.set(Some(err));
} else {
this.inner
.error
.set(Some(DispatcherError::KeepAlive));
}
} else if *this.keepalive_timeout != 0 {
// unregister keep-alive timer
this.timer.unregister(
*this.updated
+ Duration::from_secs(
*this.keepalive_timeout as u64,
),
this.state,
);
}
// unregister keep-alive timer
this.inner.unregister_keepalive();
// check for errors
let item = this
.inner
.error
.take()
.and_then(|err| err.convert())
.or_else(|| {
this.state
.take_io_error()
.map(DispatcherItem::IoError)
});
*this.st = DispatcherState::Stop;
retry = true;
item
this.inner.st = DispatcherState::Stop;
this.inner.take_error()
} else {
// decode incoming bytes stream
if this.state.is_read_ready() {
// this.state.with_read_buf(|buf| {
// log::trace!(
// "attempt to decode frame, buffer size is {:?}",
// buf
// );
// });
match this.state.decode_item() {
if this.inner.state.is_read_ready() {
let item = this
.inner
.state
.decode_item(&this.inner.shared.codec);
match item {
Ok(Some(el)) => {
// update keep-alive timer
if *this.keepalive_timeout != 0 {
let updated = this.timer.now();
if updated != *this.updated {
let ka = Duration::from_secs(
*this.keepalive_timeout as u64,
);
this.timer.register(
updated + ka,
*this.updated + ka,
this.state,
);
*this.updated = updated;
}
}
Some(DispatcherItem::Item(el))
this.inner.update_keepalive();
Some(DispatchItem::Item(el))
}
Ok(None) => {
// log::trace!("not enough data to decode next frame, register dispatch task");
this.state.dsp_read_more_data(cx.waker());
log::trace!("not enough data to decode next frame, register dispatch task");
this.inner
.state
.dsp_read_more_data(cx.waker());
return Poll::Pending;
}
Err(err) => {
retry = true;
*this.st = DispatcherState::Stop;
// unregister keep-alive timer
if *this.keepalive_timeout != 0 {
this.timer.unregister(
*this.updated
+ Duration::from_secs(
*this.keepalive_timeout
as u64,
),
this.state,
);
}
Some(DispatcherItem::DecoderError(err))
this.inner.st = DispatcherState::Stop;
this.inner.unregister_keepalive();
Some(DispatchItem::DecoderError(err))
}
}
} else {
this.state.dsp_register_task(cx.waker());
this.inner.state.dsp_register_task(cx.waker());
return Poll::Pending;
}
};
@ -336,24 +351,33 @@ where
.poll(cx);
if let Poll::Ready(res) = res {
if let Err(err) = this.state.write_result(res) {
this.inner.error.set(Some(err.into()));
if let Err(err) = this
.inner
.state
.write_result(res, &this.inner.shared.codec)
{
this.inner
.shared
.error
.set(Some(err.into()));
}
this.response.set(None);
} else {
this.inner
.shared
.inflight
.set(this.inner.inflight.get() + 1);
.set(this.inner.shared.inflight.get() + 1);
}
} else {
this.inner
.shared
.inflight
.set(this.inner.inflight.get() + 1);
let st = this.state.clone();
let inner = this.inner.clone();
.set(this.inner.shared.inflight.get() + 1);
let st = this.inner.state.clone();
let shared = this.inner.shared.clone();
crate::rt::spawn(this.service.call(item).map(
move |item| {
inner.handle_result(item, &st, true);
shared.handle_result(item, &st, true);
},
));
}
@ -367,27 +391,19 @@ where
Poll::Pending => {
// pause io read task
log::trace!("service is not ready, register dispatch task");
this.state.dsp_service_not_ready(cx.waker());
this.inner.state.dsp_service_not_ready(cx.waker());
return Poll::Pending;
}
Poll::Ready(Err(err)) => {
// handle service readiness error
log::trace!("service readiness check failed, stopping");
// service readiness error
*this.st = DispatcherState::Stop;
this.state.dsp_mark_stopped();
this.inner.error.set(Some(DispatcherError::Service(err)));
// unregister keep-alive timer
if *this.keepalive_timeout != 0 {
this.timer.unregister(
*this.updated
+ Duration::from_secs(
*this.keepalive_timeout as u64,
),
this.state,
);
}
this.inner.st = DispatcherState::Stop;
this.inner.state.dsp_mark_stopped();
this.inner
.shared
.error
.set(Some(DispatcherError::Service(err)));
this.inner.unregister_keepalive();
return self.poll(cx);
}
}
@ -398,18 +414,18 @@ where
// service may relay on poll_ready for response results
let _ = this.service.poll_ready(cx);
if this.inner.inflight.get() == 0 {
this.state.shutdown_io();
*this.st = DispatcherState::Shutdown;
if this.inner.shared.inflight.get() == 0 {
this.inner.state.shutdown_io();
this.inner.st = DispatcherState::Shutdown;
self.poll(cx)
} else {
this.state.dsp_register_task(cx.waker());
this.inner.state.dsp_register_task(cx.waker());
Poll::Pending
}
}
// shutdown service
DispatcherState::Shutdown => {
let err = this.inner.error.take();
let err = this.inner.shared.error.take();
if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
log::trace!("service shutdown is completed, stop");
@ -420,7 +436,7 @@ where
Ok(())
})
} else {
this.inner.error.set(err);
this.inner.shared.error.set(err);
Poll::Pending
}
}
@ -441,7 +457,7 @@ mod tests {
impl<S, U> Dispatcher<S, U>
where
S: Service<Request = DispatcherItem<U>, Response = Option<Response<U>>>,
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Decoder + Encoder + 'static,
@ -452,33 +468,36 @@ mod tests {
io: T,
codec: U,
service: F,
) -> (Self, State<U>)
) -> (Self, State)
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
let timer = Timer::with(Duration::from_secs(1));
let timer = Timer::default();
let keepalive_timeout = 30;
let updated = timer.now();
let state = State::new(codec);
let state = State::new();
let io = Rc::new(RefCell::new(io));
let inner = Rc::new(DispatcherInner {
let shared = Rc::new(DispatcherShared {
codec: codec,
error: Cell::new(None),
inflight: Cell::new(0),
});
crate::rt::spawn(FramedReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(FramedWriteTask::new(io.clone(), state.clone()));
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(WriteTask::new(io.clone(), state.clone()));
(
Dispatcher {
service: service.into_service(),
state: state.clone(),
st: DispatcherState::Processing,
response: None,
timer,
updated,
keepalive_timeout,
inner,
inner: DispatcherInner {
shared,
timer,
updated,
keepalive_timeout,
state: state.clone(),
st: DispatcherState::Processing,
},
},
state,
)
@ -494,9 +513,9 @@ mod tests {
let (disp, _) = Dispatcher::debug(
server,
BytesCodec,
crate::fn_service(|msg: DispatcherItem<BytesCodec>| async move {
crate::fn_service(|msg: DispatchItem<BytesCodec>| async move {
delay_for(Duration::from_millis(50)).await;
if let DispatcherItem::Item(msg) = msg {
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
@ -521,8 +540,8 @@ mod tests {
let (disp, st) = Dispatcher::debug(
server,
BytesCodec,
crate::fn_service(|msg: DispatcherItem<BytesCodec>| async move {
if let DispatcherItem::Item(msg) = msg {
crate::fn_service(|msg: DispatchItem<BytesCodec>| async move {
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
@ -534,7 +553,9 @@ mod tests {
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(st.write_item(Bytes::from_static(b"test")).is_ok());
assert!(st
.write_item(Bytes::from_static(b"test"), &mut BytesCodec)
.is_ok());
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
@ -552,14 +573,17 @@ mod tests {
let (disp, state) = Dispatcher::debug(
server,
BytesCodec,
crate::fn_service(|_: DispatcherItem<BytesCodec>| async move {
crate::fn_service(|_: DispatchItem<BytesCodec>| async move {
Err::<Option<Bytes>, _>(())
}),
);
crate::rt::spawn(disp.map(|_| ()));
state
.write_item(Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"))
.write_item(
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
&mut BytesCodec,
)
.unwrap();
let buf = client.read_any();

View file

@ -1,3 +1,5 @@
use std::{fmt, io};
mod dispatcher;
mod read;
mod state;
@ -5,7 +7,48 @@ mod time;
mod write;
pub use self::dispatcher::Dispatcher;
pub use self::read::FramedReadTask;
pub use self::state::{DispatcherItem, State};
pub use self::read::ReadTask;
pub use self::state::State;
pub use self::time::Timer;
pub use self::write::FramedWriteTask;
pub use self::write::WriteTask;
use crate::codec::{Decoder, Encoder};
/// Framed transport item
pub enum DispatchItem<U: Encoder + Decoder> {
Item(<U as Decoder>::Item),
/// Keep alive timeout
KeepAliveTimeout,
/// Decoder parse error
DecoderError(<U as Decoder>::Error),
/// Encoder parse error
EncoderError(<U as Encoder>::Error),
/// Unexpected io error
IoError(io::Error),
}
impl<U> fmt::Debug for DispatchItem<U>
where
U: Encoder + Decoder,
<U as Decoder>::Item: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
DispatchItem::Item(ref item) => {
write!(fmt, "DispatchItem::Item({:?})", item)
}
DispatchItem::KeepAliveTimeout => {
write!(fmt, "DispatchItem::KeepAliveTimeout")
}
DispatchItem::EncoderError(ref e) => {
write!(fmt, "DispatchItem::EncoderError({:?})", e)
}
DispatchItem::DecoderError(ref e) => {
write!(fmt, "DispatchItem::DecoderError({:?})", e)
}
DispatchItem::IoError(ref e) => {
write!(fmt, "DispatchItem::IoError({:?})", e)
}
}
}
}

View file

@ -11,25 +11,25 @@ const LW: usize = 1024;
const HW: usize = 8 * 1024;
/// Read io task
pub struct FramedReadTask<T, U>
pub struct ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
io: Rc<RefCell<T>>,
state: State<U>,
state: State,
}
impl<T, U> FramedReadTask<T, U>
impl<T> ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
/// Create new read io task
pub fn new(io: Rc<RefCell<T>>, state: State<U>) -> Self {
pub fn new(io: Rc<RefCell<T>>, state: State) -> Self {
Self { io, state }
}
}
impl<T, U> Future for FramedReadTask<T, U>
impl<T> Future for ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{

View file

@ -1,6 +1,6 @@
//! Framed transport dispatcher
use std::task::{Context, Poll, Waker};
use std::{cell::Cell, cell::RefCell, fmt, hash, io, mem, pin::Pin, rc::Rc};
use std::{cell::Cell, cell::RefCell, hash, io, mem, pin::Pin, rc::Rc};
use bytes::BytesMut;
use either::Either;
@ -10,13 +10,10 @@ use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}
use crate::framed::write::flush;
use crate::task::LocalWaker;
type Request<U> = <U as Decoder>::Item;
type Response<U> = <U as Encoder>::Item;
const HW: usize = 8 * 1024;
bitflags::bitflags! {
pub(crate) struct Flags: u8 {
pub struct Flags: u8 {
const DSP_STOP = 0b0000_0001;
const DSP_KEEPALIVE = 0b0000_0010;
@ -35,49 +32,9 @@ bitflags::bitflags! {
}
}
/// Framed transport item
pub enum DispatcherItem<U: Encoder + Decoder> {
Item(Request<U>),
/// Keep alive timeout
KeepAliveTimeout,
/// Decoder parse error
DecoderError(<U as Decoder>::Error),
/// Encoder parse error
EncoderError(<U as Encoder>::Error),
/// Unexpected io error
IoError(io::Error),
}
pub struct State(Rc<IoStateInner>);
impl<U> fmt::Debug for DispatcherItem<U>
where
U: Encoder + Decoder,
<U as Decoder>::Item: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
DispatcherItem::Item(ref item) => {
write!(fmt, "DispatcherItem::Item({:?})", item)
}
DispatcherItem::KeepAliveTimeout => {
write!(fmt, "DispatcherItem::KeepAliveTimeout")
}
DispatcherItem::EncoderError(ref e) => {
write!(fmt, "DispatcherItem::EncoderError({:?})", e)
}
DispatcherItem::DecoderError(ref e) => {
write!(fmt, "DispatcherItem::DecoderError({:?})", e)
}
DispatcherItem::IoError(ref e) => {
write!(fmt, "DispatcherItem::IoError({:?})", e)
}
}
}
}
pub struct State<U>(Rc<IoStateInner<U>>);
pub(crate) struct IoStateInner<U> {
codec: RefCell<U>,
pub(crate) struct IoStateInner {
flags: Cell<Flags>,
error: Cell<Option<io::Error>>,
disconnect_timeout: Cell<u16>,
@ -88,39 +45,29 @@ pub(crate) struct IoStateInner<U> {
write_buf: RefCell<BytesMut>,
}
impl<U> State<U> {
pub(crate) fn keepalive_timeout(&self) {
let state = self.0.as_ref();
let mut flags = state.flags.get();
flags.insert(Flags::DSP_STOP | Flags::DSP_KEEPALIVE);
state.flags.set(flags);
state.dispatch_task.wake();
}
}
impl<U> Clone for State<U> {
impl Clone for State {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<U> Eq for State<U> {}
impl Eq for State {}
impl<U> PartialEq for State<U> {
impl PartialEq for State {
fn eq(&self, other: &Self) -> bool {
Rc::as_ptr(&self.0) == Rc::as_ptr(&other.0)
}
}
impl<U> hash::Hash for State<U> {
impl hash::Hash for State {
fn hash<H: hash::Hasher>(&self, state: &mut H) {
Rc::as_ptr(&self.0).hash(state);
}
}
impl<U> State<U> {
impl State {
/// Create `State` instance
pub fn new(codec: U) -> Self {
pub fn new() -> Self {
State(Rc::new(IoStateInner {
flags: Cell::new(Flags::empty()),
error: Cell::new(None),
@ -128,14 +75,13 @@ impl<U> State<U> {
dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(),
write_task: LocalWaker::new(),
codec: RefCell::new(codec),
read_buf: RefCell::new(BytesMut::new()),
write_buf: RefCell::new(BytesMut::new()),
}))
}
/// Create `State` from Framed
pub fn from_framed<Io>(framed: Framed<Io, U>) -> (Io, Self) {
pub fn from_framed<Io, U>(framed: Framed<Io, U>) -> (Io, U, Self) {
let parts = framed.into_parts();
let state = State(Rc::new(IoStateInner {
@ -145,30 +91,38 @@ impl<U> State<U> {
dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(),
write_task: LocalWaker::new(),
codec: RefCell::new(parts.codec),
read_buf: RefCell::new(parts.read_buf),
write_buf: RefCell::new(parts.write_buf),
}));
(parts.io, state)
(parts.io, parts.codec, state)
}
/// Convert state to a Framed instance
pub fn into_framed<Io>(self, io: Io) -> Result<Framed<Io, U>, Io> {
match Rc::try_unwrap(self.0) {
Ok(inner) => {
let mut parts = FramedParts::new(io, inner.codec.into_inner());
parts.read_buf = inner.read_buf.into_inner();
parts.write_buf = inner.write_buf.into_inner();
Ok(Framed::from_parts(parts))
}
Err(_) => Err(io),
}
pub fn into_framed<Io, U>(self, io: Io, codec: U) -> Framed<Io, U> {
let mut parts = FramedParts::new(io, codec);
parts.read_buf = mem::take(&mut self.0.read_buf.borrow_mut());
parts.write_buf = mem::take(&mut self.0.write_buf.borrow_mut());
Framed::from_parts(parts)
}
pub(crate) fn keepalive_timeout(&self) {
let state = self.0.as_ref();
let mut flags = state.flags.get();
flags.insert(Flags::DSP_KEEPALIVE);
state.flags.set(flags);
state.dispatch_task.wake();
}
pub(super) fn disconnect_timeout(&self) -> u16 {
self.0.disconnect_timeout.get()
}
#[inline]
/// Get current state flags
pub fn flags(&self) -> Flags {
self.0.flags.get()
}
#[inline]
/// Set disconnecto timeout
pub fn set_disconnect_timeout(&self, timeout: u16) {
@ -212,10 +166,18 @@ impl<U> State<U> {
#[inline]
/// Check if keep-alive timeout occured
pub fn is_keepalive_err(&self) -> bool {
pub fn is_keepalive(&self) -> bool {
self.0.flags.get().contains(Flags::DSP_KEEPALIVE)
}
#[inline]
/// Reset keep-alive error
pub fn reset_keepalive(&self) {
let mut flags = self.0.flags.get();
flags.remove(Flags::DSP_KEEPALIVE);
self.0.flags.set(flags);
}
#[inline]
/// Check is dispatcher marked stopped
pub fn is_dsp_stopped(&self) -> bool {
@ -377,6 +339,7 @@ impl<U> State<U> {
}
#[inline]
/// Get mut access to read buffer
pub fn with_read_buf<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
@ -385,6 +348,7 @@ impl<U> State<U> {
}
#[inline]
/// Get mut access to write buffer
pub fn with_write_buf<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
@ -393,56 +357,31 @@ impl<U> State<U> {
}
}
impl<U> State<U>
where
U: Encoder + Decoder,
{
impl State {
#[inline]
/// Consume the `IoState`, returning `IoState` with different codec.
pub fn map_codec<F, U2>(self, f: F) -> State<U2>
/// Attempts to decode a frame from the read buffer.
pub fn decode_item<U>(
&self,
codec: &U,
) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error>
where
F: Fn(&U) -> U2,
U2: Encoder + Decoder,
U: Decoder,
{
let st = self.0.as_ref();
let codec = f(&st.codec.borrow());
State(Rc::new(IoStateInner {
codec: RefCell::new(codec),
flags: Cell::new(st.flags.get()),
error: Cell::new(st.error.take()),
disconnect_timeout: Cell::new(st.disconnect_timeout.get()),
dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(),
write_task: LocalWaker::new(),
read_buf: RefCell::new(mem::take(&mut st.read_buf.borrow_mut())),
write_buf: RefCell::new(mem::take(&mut st.write_buf.borrow_mut())),
}))
codec.decode(&mut self.0.read_buf.borrow_mut())
}
#[inline]
pub fn with_codec<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut U) -> R,
{
f(&mut *self.0.codec.borrow_mut())
}
#[inline]
pub async fn next<T>(
pub async fn next<T, U>(
&self,
io: &mut T,
) -> Result<Option<<U as Decoder>::Item>, Either<<U as Decoder>::Error, io::Error>>
codec: &mut U,
) -> Result<Option<U::Item>, Either<U::Error, io::Error>>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Decoder,
{
loop {
let item = {
self.0
.codec
.borrow_mut()
.decode(&mut self.0.read_buf.borrow_mut())
};
let item = codec.decode(&mut self.0.read_buf.borrow_mut());
return match item {
Ok(Some(el)) => Ok(Some(el)),
Ok(None) => {
@ -468,18 +407,17 @@ where
}
#[inline]
pub fn poll_next<T>(
pub fn poll_next<T, U>(
&self,
io: &mut T,
codec: &mut U,
cx: &mut Context<'_>,
) -> Poll<
Result<Option<<U as Decoder>::Item>, Either<<U as Decoder>::Error, io::Error>>,
>
) -> Poll<Result<Option<U::Item>, Either<U::Error, io::Error>>>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Decoder,
{
let mut buf = self.0.read_buf.borrow_mut();
let mut codec = self.0.codec.borrow_mut();
loop {
return match codec.decode(&mut buf) {
@ -502,17 +440,18 @@ where
}
#[inline]
pub async fn send<T>(
/// Encode item, send to a peer and flush
pub async fn send<T, U>(
&self,
io: &mut T,
item: <U as Encoder>::Item,
) -> Result<(), Either<<U as Encoder>::Error, io::Error>>
codec: &U,
item: U::Item,
) -> Result<(), Either<U::Error, io::Error>>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder,
{
self.0
.codec
.borrow_mut()
codec
.encode(item, &mut self.0.write_buf.borrow_mut())
.map_err(Either::Left)?;
@ -525,25 +464,18 @@ where
})
}
#[inline]
/// Attempts to decode a frame from the read buffer.
pub fn decode_item(
&self,
) -> Result<Option<<U as Decoder>::Item>, <U as Decoder>::Error> {
self.0
.codec
.borrow_mut()
.decode(&mut self.0.read_buf.borrow_mut())
}
#[inline]
/// Write item to a buf and wake up io task
///
/// Returns state of write buffer state, false is returned if write buffer if full.
pub fn write_item(
pub fn write_item<U>(
&self,
item: <U as Encoder>::Item,
) -> Result<bool, <U as Encoder>::Error> {
item: U::Item,
codec: &U,
) -> Result<bool, <U as Encoder>::Error>
where
U: Encoder,
{
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
@ -551,10 +483,7 @@ where
let is_write_sleep = write_buf.is_empty();
// encode item and wake write task
let res = self
.0
.codec
.borrow_mut()
let res = codec
.encode(item, &mut *write_buf)
.map(|_| write_buf.len() < HW);
if res.is_ok() && is_write_sleep {
@ -568,10 +497,14 @@ where
#[inline]
/// Write item to a buf and wake up io task
pub fn write_result<E>(
pub fn write_result<U, E>(
&self,
item: Result<Option<Response<U>>, E>,
) -> Result<bool, Either<E, <U as Encoder>::Error>> {
item: Result<Option<U::Item>, E>,
codec: &U,
) -> Result<bool, Either<E, U::Error>>
where
U: Encoder,
{
let mut flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::ST_DSP_ERR) {
@ -581,9 +514,7 @@ where
let is_write_sleep = write_buf.is_empty();
// encode item
if let Err(err) =
self.0.codec.borrow_mut().encode(item, &mut write_buf)
{
if let Err(err) = codec.encode(item, &mut write_buf) {
log::trace!("Codec encoder error: {:?}", err);
flags.insert(Flags::DSP_STOP | Flags::ST_DSP_ERR);
self.0.flags.set(flags);

View file

@ -6,15 +6,15 @@ use crate::framed::State;
use crate::rt::time::delay_for;
use crate::HashSet;
pub struct Timer<U>(Rc<RefCell<Inner<U>>>);
pub struct Timer(Rc<RefCell<Inner>>);
struct Inner<U> {
struct Inner {
resolution: Duration,
current: Option<Instant>,
notifications: BTreeMap<Instant, HashSet<State<U>>>,
notifications: BTreeMap<Instant, HashSet<State>>,
}
impl<U> Inner<U> {
impl Inner {
fn new(resolution: Duration) -> Self {
Inner {
resolution,
@ -23,7 +23,7 @@ impl<U> Inner<U> {
}
}
fn unregister(&mut self, expire: Instant, state: &State<U>) {
fn unregister(&mut self, expire: Instant, state: &State) {
if let Some(ref mut states) = self.notifications.get_mut(&expire) {
states.remove(state);
if states.is_empty() {
@ -33,18 +33,24 @@ impl<U> Inner<U> {
}
}
impl<U> Clone for Timer<U> {
impl Clone for Timer {
fn clone(&self) -> Self {
Timer(self.0.clone())
}
}
impl<U: 'static> Timer<U> {
pub fn with(resolution: Duration) -> Timer<U> {
impl Default for Timer {
fn default() -> Self {
Timer::with(Duration::from_secs(1))
}
}
impl Timer {
pub fn with(resolution: Duration) -> Timer {
Timer(Rc::new(RefCell::new(Inner::new(resolution))))
}
pub fn register(&self, expire: Instant, previous: Instant, state: &State<U>) {
pub fn register(&self, expire: Instant, previous: Instant, state: &State) {
{
let mut inner = self.0.borrow_mut();
@ -59,7 +65,7 @@ impl<U: 'static> Timer<U> {
let _ = self.now();
}
pub fn unregister(&self, expire: Instant, state: &State<U>) {
pub fn unregister(&self, expire: Instant, state: &State) {
self.0.borrow_mut().unregister(expire, state);
}

View file

@ -3,7 +3,7 @@ use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc, time::Duration};
use bytes::{Buf, BytesMut};
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::State;
use crate::rt::time::{delay_for, Delay};
@ -23,25 +23,21 @@ enum Shutdown {
}
/// Write io task
pub struct FramedWriteTask<T, U>
pub struct WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
{
st: IoWriteState,
io: Rc<RefCell<T>>,
state: State<U>,
state: State,
}
impl<T, U> FramedWriteTask<T, U>
impl<T> WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
{
/// Create new write io task
pub fn new(io: Rc<RefCell<T>>, state: State<U>) -> Self {
pub fn new(io: Rc<RefCell<T>>, state: State) -> Self {
Self {
io,
state,
@ -50,7 +46,7 @@ where
}
/// Shutdown io stream
pub fn shutdown(io: Rc<RefCell<T>>, state: State<U>) -> Self {
pub fn shutdown(io: Rc<RefCell<T>>, state: State) -> Self {
let disconnect_timeout = state.disconnect_timeout() as u64;
let st = IoWriteState::Shutdown(
if disconnect_timeout != 0 {
@ -65,11 +61,9 @@ where
}
}
impl<T, U> Future for FramedWriteTask<T, U>
impl<T> Future for WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
{
type Output = ();
@ -204,8 +198,8 @@ pub(super) fn flush<T>(
where
T: AsyncRead + AsyncWrite + Unpin,
{
// log::trace!("flushing framed transport: {}", len);
let len = buf.len();
log::trace!("flushing framed transport: {}", len);
if len != 0 {
let mut written = 0;

View file

@ -1,8 +1,6 @@
use std::fmt;
use std::marker::PhantomData;
use std::rc::Rc;
use std::{error::Error, fmt, marker::PhantomData, rc::Rc};
use crate::codec::Framed;
use crate::framed::State;
use crate::http::body::MessageBody;
use crate::http::config::{KeepAlive, ServiceConfig};
use crate::http::error::ResponseError;
@ -34,9 +32,9 @@ impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> {
pub fn new() -> Self {
HttpServiceBuilder {
keep_alive: KeepAlive::Timeout(5),
client_timeout: 3000,
client_disconnect: 3000,
handshake_timeout: 5000,
client_timeout: 3,
client_disconnect: 3,
handshake_timeout: 5,
expect: ExpectHandler,
upgrade: None,
on_connect: None,
@ -53,12 +51,12 @@ where
S::Future: 'static,
<S::Service as Service>::Future: 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display,
U: ServiceFactory<Config = (), Request = (Request, State, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
<U::Service as Service>::Future: 'static,
@ -71,7 +69,7 @@ where
self
}
/// Set server client timeout in milliseconds for first request.
/// Set server client timeout in seconds for first request.
///
/// Defines a timeout for reading client request header. If a client does not transmit
/// the entire set headers within this time, the request is terminated with
@ -80,8 +78,8 @@ where
/// To disable timeout set value to 0.
///
/// By default client timeout is set to 3 seconds.
pub fn client_timeout(mut self, val: u64) -> Self {
self.client_timeout = val;
pub fn client_timeout(mut self, val: u16) -> Self {
self.client_timeout = val as u64;
self
}
@ -93,19 +91,19 @@ where
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 3 seconds.
pub fn disconnect_timeout(mut self, val: u64) -> Self {
self.client_disconnect = val;
pub fn disconnect_timeout(mut self, val: u16) -> Self {
self.client_disconnect = val as u64;
self
}
/// Set server ssl handshake timeout in milliseconds.
/// Set server ssl handshake timeout in seconds.
///
/// Defines a timeout for connection ssl handshake negotiation.
/// To disable timeout set value to 0.
///
/// By default handshake timeout is set to 5 seconds.
pub fn ssl_handshake_timeout(mut self, val: u64) -> Self {
self.handshake_timeout = val;
pub fn ssl_handshake_timeout(mut self, val: u16) -> Self {
self.handshake_timeout = val as u64;
self
}
@ -118,7 +116,7 @@ where
where
F: IntoServiceFactory<X1>,
X1: ServiceFactory<Config = (), Request = Request, Response = Request>,
X1::Error: ResponseError,
X1::Error: ResponseError + 'static,
X1::InitError: fmt::Debug,
X1::Future: 'static,
<X1::Service as Service>::Future: 'static,
@ -144,10 +142,10 @@ where
F: IntoServiceFactory<U1>,
U1: ServiceFactory<
Config = (),
Request = (Request, Framed<T, Codec>),
Request = (Request, State, Codec),
Response = (),
>,
U1::Error: fmt::Display,
U1::Error: fmt::Display + Error + 'static,
U1::InitError: fmt::Debug,
U1::Future: 'static,
<U1::Service as Service>::Future: 'static,

View file

@ -376,10 +376,7 @@ impl ClientRequest {
pub fn freeze(self) -> Result<FrozenClientRequest, FreezeRequestError> {
let slf = match self.prep_for_sending() {
Ok(slf) => slf,
Err(e) => {
println!("E: {:?}", e);
return Err(e.into());
}
Err(e) => return Err(e.into()),
};
let request = FrozenClientRequest {

View file

@ -1,8 +1,6 @@
use std::cell::{Ref, RefMut};
use std::fmt;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, marker::PhantomData, mem, pin::Pin};
use bytes::{Bytes, BytesMut};
use futures::{ready, Future, Stream};
@ -13,18 +11,19 @@ use coo_kie::{Cookie, ParseError as CookieParseError};
use crate::http::error::PayloadError;
use crate::http::header::CONTENT_LENGTH;
use crate::http::{Extensions, HttpMessage, Payload, PayloadStream, ResponseHead};
use crate::http::{HeaderMap, StatusCode, Version};
use crate::http::{HttpMessage, Payload, ResponseHead};
use crate::util::Extensions;
use super::error::JsonPayloadError;
/// Client Response
pub struct ClientResponse<S = PayloadStream> {
pub struct ClientResponse {
pub(crate) head: ResponseHead,
pub(crate) payload: Payload<S>,
pub(crate) payload: Payload,
}
impl<S> HttpMessage for ClientResponse<S> {
impl HttpMessage for ClientResponse {
fn message_headers(&self) -> &HeaderMap {
&self.head.headers
}
@ -59,9 +58,9 @@ impl<S> HttpMessage for ClientResponse<S> {
}
}
impl<S> ClientResponse<S> {
impl ClientResponse {
/// Create new Request instance
pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self {
pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self {
ClientResponse { head, payload }
}
@ -89,21 +88,13 @@ impl<S> ClientResponse<S> {
}
/// Set a body and return previous body value
pub fn map_body<F, U>(mut self, f: F) -> ClientResponse<U>
where
F: FnOnce(&mut ResponseHead, Payload<S>) -> Payload<U>,
{
let payload = f(&mut self.head, self.payload);
ClientResponse {
payload,
head: self.head,
}
pub fn set_payload(&mut self, payload: Payload) {
self.payload = payload;
}
/// Get response's payload
pub fn take_payload(&mut self) -> Payload<S> {
std::mem::take(&mut self.payload)
pub fn take_payload(&mut self) -> Payload {
mem::take(&mut self.payload)
}
/// Request extensions
@ -119,12 +110,9 @@ impl<S> ClientResponse<S> {
}
}
impl<S> ClientResponse<S>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
{
impl ClientResponse {
/// Loads http response's body.
pub fn body(&mut self) -> MessageBody<S> {
pub fn body(&mut self) -> MessageBody {
MessageBody::new(self)
}
@ -135,15 +123,12 @@ where
///
/// * content type is not `application/json`
/// * content length is greater than 256k
pub fn json<T: DeserializeOwned>(&mut self) -> JsonBody<S, T> {
pub fn json<T: DeserializeOwned>(&mut self) -> JsonBody<T> {
JsonBody::new(self)
}
}
impl<S> Stream for ClientResponse<S>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
impl Stream for ClientResponse {
type Item = Result<Bytes, PayloadError>;
fn poll_next(
@ -154,7 +139,7 @@ where
}
}
impl<S> fmt::Debug for ClientResponse<S> {
impl fmt::Debug for ClientResponse {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?;
writeln!(f, " headers:")?;
@ -166,18 +151,15 @@ impl<S> fmt::Debug for ClientResponse<S> {
}
/// Future that resolves to a complete http message body.
pub struct MessageBody<S> {
pub struct MessageBody {
length: Option<usize>,
err: Option<PayloadError>,
fut: Option<ReadBody<S>>,
fut: Option<ReadBody>,
}
impl<S> MessageBody<S>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
{
impl MessageBody {
/// Create `MessageBody` for request.
pub fn new(res: &mut ClientResponse<S>) -> MessageBody<S> {
pub fn new(res: &mut ClientResponse) -> MessageBody {
let mut len = None;
if let Some(l) = res.headers().get(&CONTENT_LENGTH) {
if let Ok(s) = l.to_str() {
@ -215,10 +197,7 @@ where
}
}
impl<S> Future for MessageBody<S>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
impl Future for MessageBody {
type Output = Result<Bytes, PayloadError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -244,20 +223,19 @@ where
///
/// * content type is not `application/json`
/// * content length is greater than 64k
pub struct JsonBody<S, U> {
pub struct JsonBody<U> {
length: Option<usize>,
err: Option<JsonPayloadError>,
fut: Option<ReadBody<S>>,
fut: Option<ReadBody>,
_t: PhantomData<U>,
}
impl<S, U> JsonBody<S, U>
impl<U> JsonBody<U>
where
S: Stream<Item = Result<Bytes, PayloadError>>,
U: DeserializeOwned,
{
/// Create `JsonBody` for request.
pub fn new(req: &mut ClientResponse<S>) -> Self {
pub fn new(req: &mut ClientResponse) -> Self {
// check content-type
let json = if let Ok(Some(mime)) = req.mime_type() {
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
@ -299,16 +277,10 @@ where
}
}
impl<T, U> Unpin for JsonBody<T, U>
where
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
U: DeserializeOwned,
{
}
impl<U> Unpin for JsonBody<U> where U: DeserializeOwned {}
impl<T, U> Future for JsonBody<T, U>
impl<U> Future for JsonBody<U>
where
T: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
U: DeserializeOwned,
{
type Output = Result<U, JsonPayloadError>;
@ -331,14 +303,14 @@ where
}
}
struct ReadBody<S> {
stream: Payload<S>,
struct ReadBody {
stream: Payload,
buf: BytesMut,
limit: usize,
}
impl<S> ReadBody<S> {
fn new(stream: Payload<S>, limit: usize) -> Self {
impl ReadBody {
fn new(stream: Payload, limit: usize) -> Self {
Self {
stream,
buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)),
@ -347,10 +319,7 @@ impl<S> ReadBody<S> {
}
}
impl<S> Future for ReadBody<S>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
impl Future for ReadBody {
type Output = Result<Bytes, PayloadError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -432,7 +401,7 @@ mod tests {
#[ntex_rt::test]
async fn test_json_body() {
let mut req = TestResponse::default().finish();
let json = JsonBody::<_, MyObject>::new(&mut req).await;
let json = JsonBody::<MyObject>::new(&mut req).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default()
@ -441,7 +410,7 @@ mod tests {
header::HeaderValue::from_static("application/text"),
)
.finish();
let json = JsonBody::<_, MyObject>::new(&mut req).await;
let json = JsonBody::<MyObject>::new(&mut req).await;
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let mut req = TestResponse::default()
@ -455,7 +424,7 @@ mod tests {
)
.finish();
let json = JsonBody::<_, MyObject>::new(&mut req).limit(100).await;
let json = JsonBody::<MyObject>::new(&mut req).limit(100).await;
assert!(json_eq(
json.err().unwrap(),
JsonPayloadError::Payload(PayloadError::Overflow)
@ -473,7 +442,7 @@ mod tests {
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.finish();
let json = JsonBody::<_, MyObject>::new(&mut req).await;
let json = JsonBody::<MyObject>::new(&mut req).await;
assert_eq!(
json.ok().unwrap(),
MyObject {

View file

@ -19,9 +19,7 @@ use crate::rt::time::{delay_for, Delay};
#[cfg(feature = "compress")]
use crate::http::encoding::Decoder;
#[cfg(feature = "compress")]
use crate::http::header::ContentEncoding;
#[cfg(feature = "compress")]
use crate::http::{Payload, PayloadStream};
use crate::http::Payload;
use super::error::{FreezeRequestError, InvalidUrl, SendRequestError};
use super::response::ClientResponse;
@ -74,10 +72,6 @@ impl SendClientRequest {
}
impl Future for SendClientRequest {
#[cfg(feature = "compress")]
type Output =
Result<ClientResponse<Decoder<Payload<PayloadStream>>>, SendRequestError>;
#[cfg(not(feature = "compress"))]
type Output = Result<ClientResponse, SendRequestError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -95,20 +89,15 @@ impl Future for SendClientRequest {
let res = futures::ready!(Pin::new(send).poll(cx));
#[cfg(feature = "compress")]
let res = res.map(|res| {
res.map_body(|head, payload| {
if *_response_decompress {
Payload::Stream(Decoder::from_headers(
payload,
&head.headers,
))
} else {
Payload::Stream(Decoder::new(
payload,
ContentEncoding::Identity,
))
}
})
let res = res.map(|mut res| {
if *_response_decompress {
let payload = res.take_payload();
res.set_payload(Payload::from_stream(Decoder::from_headers(
payload,
&res.head.headers,
)))
}
res
});
Poll::Ready(res)

View file

@ -1,14 +1,12 @@
use std::cell::UnsafeCell;
use std::fmt;
use std::fmt::Write;
use std::ptr::copy_nonoverlapping;
use std::rc::Rc;
use std::time::Duration;
use std::{
cell::UnsafeCell, fmt, fmt::Write, ptr::copy_nonoverlapping, rc::Rc, time::Duration,
};
use bytes::BytesMut;
use futures::{future, FutureExt};
use time::OffsetDateTime;
use crate::framed::Timer;
use crate::rt::time::{delay_for, delay_until, Delay, Instant};
// "Sun, 06 Nov 1994 08:49:37 GMT".len()
@ -45,12 +43,13 @@ impl From<Option<usize>> for KeepAlive {
pub struct ServiceConfig(pub(super) Rc<Inner>);
pub(super) struct Inner {
pub(super) keep_alive: Option<Duration>,
pub(super) keep_alive: u64,
pub(super) client_timeout: u64,
pub(super) client_disconnect: u64,
pub(super) ka_enabled: bool,
pub(super) timer: DateService,
pub(super) ssl_handshake_timeout: u64,
pub(super) timer_h1: Timer,
}
impl Clone for ServiceConfig {
@ -79,9 +78,9 @@ impl ServiceConfig {
KeepAlive::Disabled => (0, false),
};
let keep_alive = if ka_enabled && keep_alive > 0 {
Some(Duration::from_secs(keep_alive))
keep_alive
} else {
None
0
};
ServiceConfig(Rc::new(Inner {
@ -91,6 +90,7 @@ impl ServiceConfig {
client_disconnect,
ssl_handshake_timeout,
timer: DateService::new(),
timer_h1: Timer::default(),
}))
}
}
@ -99,11 +99,12 @@ pub(super) struct DispatcherConfig<S, X, U> {
pub(super) service: S,
pub(super) expect: X,
pub(super) upgrade: Option<U>,
pub(super) keep_alive: Option<Duration>,
pub(super) keep_alive: u64,
pub(super) client_timeout: u64,
pub(super) client_disconnect: u64,
pub(super) ka_enabled: bool,
pub(super) timer: DateService,
pub(super) timer_h1: Timer,
}
impl<S, X, U> DispatcherConfig<S, X, U> {
@ -122,6 +123,7 @@ impl<S, X, U> DispatcherConfig<S, X, U> {
client_disconnect: cfg.0.client_disconnect,
ka_enabled: cfg.0.ka_enabled,
timer: cfg.0.timer.clone(),
timer_h1: cfg.0.timer_h1.clone(),
}
}
@ -130,37 +132,12 @@ impl<S, X, U> DispatcherConfig<S, X, U> {
self.ka_enabled
}
/// Client timeout for first request.
pub(super) fn client_timer(&self) -> Option<Delay> {
let delay_time = self.client_timeout;
if delay_time != 0 {
Some(delay_until(
self.timer.now() + Duration::from_millis(delay_time),
))
} else {
None
}
}
/// Client disconnect timer
pub(super) fn client_disconnect_timer(&self) -> Option<Instant> {
let delay = self.client_disconnect;
if delay != 0 {
Some(self.timer.now() + Duration::from_millis(delay))
} else {
None
}
}
/// Return state of connection keep-alive timer
pub(super) fn keep_alive_timer_enabled(&self) -> bool {
self.keep_alive.is_some()
}
/// Return keep-alive timer delay is configured.
pub(super) fn keep_alive_timer(&self) -> Option<Delay> {
if let Some(ka) = self.keep_alive {
Some(delay_until(self.timer.now() + ka))
if self.keep_alive != 0 {
Some(delay_until(
self.timer.now() + Duration::from_secs(self.keep_alive),
))
} else {
None
}
@ -168,8 +145,8 @@ impl<S, X, U> DispatcherConfig<S, X, U> {
/// Keep-alive expire time
pub(super) fn keep_alive_expire(&self) -> Option<Instant> {
if let Some(ka) = self.keep_alive {
Some(self.timer.now() + ka)
if self.keep_alive != 0 {
Some(self.timer.now() + Duration::from_secs(self.keep_alive))
} else {
None
}

View file

@ -16,7 +16,7 @@ use super::body::Body;
use super::response::Response;
/// Error that can be converted to `Response`
pub trait ResponseError: fmt::Display + fmt::Debug + 'static {
pub trait ResponseError: fmt::Display + fmt::Debug {
/// Create response for error
///
/// Internal server error is generated by default.
@ -32,6 +32,12 @@ pub trait ResponseError: fmt::Display + fmt::Debug + 'static {
}
}
impl<'a, T: ResponseError> ResponseError for &'a T {
fn error_response(&self) -> Response {
(*self).error_response()
}
}
impl<T: ResponseError> From<T> for Response {
fn from(err: T) -> Response {
let resp = err.error_response();
@ -180,8 +186,9 @@ pub enum DispatchError {
/// Service error
Service(Box<dyn ResponseError>),
#[from(ignore)]
/// Upgrade service error
Upgrade,
Upgrade(Box<dyn std::error::Error>),
/// An `io::Error` that occurred while trying to read or write to a network
/// stream.
@ -192,6 +199,11 @@ pub enum DispatchError {
#[display(fmt = "Parse error: {}", _0)]
Parse(ParseError),
/// Http response encoding error.
#[display(fmt = "Encode error: {}", _0)]
#[from(ignore)]
Encode(io::Error),
/// Http/2 error
#[display(fmt = "{}", _0)]
H2(h2::Error),
@ -212,6 +224,10 @@ pub enum DispatchError {
#[display(fmt = "Malformed request")]
MalformedRequest,
/// Response body processing error
#[display(fmt = "Response body processing error: {}", _0)]
ResponsePayload(Box<dyn std::error::Error>),
/// Internal error
#[display(fmt = "Internal error")]
InternalError,

View file

@ -1,4 +1,4 @@
use std::io;
use std::{cell::Cell, cell::RefCell, io};
use bitflags::bitflags;
use bytes::{Bytes, BytesMut};
@ -34,12 +34,12 @@ pub struct ClientPayloadCodec {
struct ClientCodecInner {
timer: DateService,
decoder: decoder::MessageDecoder<ResponseHead>,
payload: Option<PayloadDecoder>,
version: Version,
ctype: ConnectionType,
payload: RefCell<Option<PayloadDecoder>>,
version: Cell<Version>,
ctype: Cell<ConnectionType>,
// encoder part
flags: Flags,
flags: Cell<Flags>,
encoder: encoder::MessageEncoder<RequestHeadType>,
}
@ -63,11 +63,10 @@ impl ClientCodec {
inner: ClientCodecInner {
timer,
decoder: decoder::MessageDecoder::default(),
payload: None,
version: Version::HTTP_11,
ctype: ConnectionType::Close,
flags,
payload: RefCell::new(None),
version: Cell::new(Version::HTTP_11),
ctype: Cell::new(ConnectionType::Close),
flags: Cell::new(flags),
encoder: encoder::MessageEncoder::default(),
},
}
@ -75,19 +74,19 @@ impl ClientCodec {
/// Check if request is upgrade
pub fn upgrade(&self) -> bool {
self.inner.ctype == ConnectionType::Upgrade
self.inner.ctype.get() == ConnectionType::Upgrade
}
/// Check if last response is keep-alive
pub fn keepalive(&self) -> bool {
self.inner.ctype == ConnectionType::KeepAlive
self.inner.ctype.get() == ConnectionType::KeepAlive
}
/// Check last request's message type
pub fn message_type(&self) -> MessageType {
if self.inner.flags.contains(Flags::STREAM) {
if self.inner.flags.get().contains(Flags::STREAM) {
MessageType::Stream
} else if self.inner.payload.is_none() {
} else if self.inner.payload.borrow().is_none() {
MessageType::None
} else {
MessageType::Payload
@ -103,7 +102,7 @@ impl ClientCodec {
impl ClientPayloadCodec {
/// Check if last response is keep-alive
pub fn keepalive(&self) -> bool {
self.inner.ctype == ConnectionType::KeepAlive
self.inner.ctype.get() == ConnectionType::KeepAlive
}
/// Transform payload codec to a message codec
@ -116,30 +115,37 @@ impl Decoder for ClientCodec {
type Item = ResponseHead;
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
debug_assert!(!self.inner.payload.is_some(), "Payload decoder is set");
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
debug_assert!(
!self.inner.payload.borrow().is_some(),
"Payload decoder is set"
);
if let Some((req, payload)) = self.inner.decoder.decode(src)? {
if let Some(ctype) = req.ctype() {
// do not use peer's keep-alive
self.inner.ctype = if ctype == ConnectionType::KeepAlive {
self.inner.ctype
} else {
ctype
if ctype != ConnectionType::KeepAlive {
self.inner.ctype.set(ctype);
};
}
if !self.inner.flags.contains(Flags::HEAD) {
if !self.inner.flags.get().contains(Flags::HEAD) {
match payload {
PayloadType::None => self.inner.payload = None,
PayloadType::Payload(pl) => self.inner.payload = Some(pl),
PayloadType::None => {
self.inner.payload.borrow_mut().take();
}
PayloadType::Payload(pl) => {
*self.inner.payload.borrow_mut() = Some(pl)
}
PayloadType::Stream(pl) => {
self.inner.payload = Some(pl);
self.inner.flags.insert(Flags::STREAM);
*self.inner.payload.borrow_mut() = Some(pl);
let mut flags = self.inner.flags.get();
flags.insert(Flags::STREAM);
self.inner.flags.set(flags);
}
}
} else {
self.inner.payload = None;
self.inner.payload.borrow_mut().take();
}
reserve_readbuf(src);
Ok(Some(req))
@ -153,19 +159,27 @@ impl Decoder for ClientPayloadCodec {
type Item = Option<Bytes>;
type Error = PayloadError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
debug_assert!(
self.inner.payload.is_some(),
self.inner.payload.borrow().is_some(),
"Payload decoder is not specified"
);
Ok(match self.inner.payload.as_mut().unwrap().decode(src)? {
let item = self
.inner
.payload
.borrow_mut()
.as_mut()
.unwrap()
.decode(src)?;
Ok(match item {
Some(PayloadItem::Chunk(chunk)) => {
reserve_readbuf(src);
Some(Some(chunk))
}
Some(PayloadItem::Eof) => {
self.inner.payload.take();
self.inner.payload.borrow_mut().take();
Some(None)
}
None => None,
@ -177,23 +191,19 @@ impl Encoder for ClientCodec {
type Item = Message<(RequestHeadType, BodySize)>;
type Error = io::Error;
fn encode(
&mut self,
item: Self::Item,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Item((mut head, length)) => {
let inner = &mut self.inner;
inner.version = head.as_ref().version;
inner
.flags
.set(Flags::HEAD, head.as_ref().method == Method::HEAD);
let inner = &self.inner;
inner.version.set(head.as_ref().version);
let mut flags = inner.flags.get();
flags.set(Flags::HEAD, head.as_ref().method == Method::HEAD);
inner.flags.set(flags);
// connection status
inner.ctype = match head.as_ref().connection_type() {
inner.ctype.set(match head.as_ref().connection_type() {
ConnectionType::KeepAlive => {
if inner.flags.contains(Flags::KEEPALIVE_ENABLED) {
if inner.flags.get().contains(Flags::KEEPALIVE_ENABLED) {
ConnectionType::KeepAlive
} else {
ConnectionType::Close
@ -201,16 +211,16 @@ impl Encoder for ClientCodec {
}
ConnectionType::Upgrade => ConnectionType::Upgrade,
ConnectionType::Close => ConnectionType::Close,
};
});
inner.encoder.encode(
dst,
&mut head,
false,
false,
inner.version,
inner.version.get(),
length,
inner.ctype,
inner.ctype.get(),
&inner.timer,
)?;
}

View file

@ -1,4 +1,4 @@
use std::{fmt, io};
use std::{cell::Cell, fmt, io};
use bitflags::bitflags;
use bytes::BytesMut;
@ -12,15 +12,13 @@ use crate::http::message::ConnectionType;
use crate::http::request::Request;
use crate::http::response::Response;
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
use super::{decoder, encoder};
use super::{Message, MessageType};
use super::{decoder, decoder::PayloadType, encoder, Message};
bitflags! {
struct Flags: u8 {
const HEAD = 0b0000_0001;
const KEEPALIVE_ENABLED = 0b0000_0010;
const STREAM = 0b0000_0100;
const STREAM = 0b0000_0010;
const KEEPALIVE_ENABLED = 0b0000_0100;
}
}
@ -28,12 +26,11 @@ bitflags! {
pub struct Codec {
timer: DateService,
decoder: decoder::MessageDecoder<Request>,
payload: Option<PayloadDecoder>,
version: Version,
ctype: ConnectionType,
version: Cell<Version>,
ctype: Cell<ConnectionType>,
// encoder part
flags: Flags,
flags: Cell<Flags>,
encoder: encoder::MessageEncoder<Response<()>>,
}
@ -43,6 +40,19 @@ impl Default for Codec {
}
}
impl Clone for Codec {
fn clone(&self) -> Self {
Codec {
timer: self.timer.clone(),
decoder: self.decoder.clone(),
version: self.version.clone(),
ctype: self.ctype.clone(),
flags: self.flags.clone(),
encoder: self.encoder.clone(),
}
}
}
impl fmt::Debug for Codec {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "h1::Codec({:?})", self.flags)
@ -61,12 +71,11 @@ impl Codec {
};
Codec {
flags,
timer,
flags: Cell::new(flags),
decoder: decoder::MessageDecoder::default(),
payload: None,
version: Version::HTTP_11,
ctype: ConnectionType::Close,
version: Cell::new(Version::HTTP_11),
ctype: Cell::new(ConnectionType::Close),
encoder: encoder::MessageEncoder::default(),
}
}
@ -74,31 +83,19 @@ impl Codec {
#[inline]
/// Check if request is upgrade
pub fn upgrade(&self) -> bool {
self.ctype == ConnectionType::Upgrade
self.ctype.get() == ConnectionType::Upgrade
}
#[inline]
/// Check if last response is keep-alive
pub fn keepalive(&self) -> bool {
self.ctype == ConnectionType::KeepAlive
self.ctype.get() == ConnectionType::KeepAlive
}
#[inline]
/// Check if keep-alive enabled on server level
pub fn keepalive_enabled(&self) -> bool {
self.flags.contains(Flags::KEEPALIVE_ENABLED)
}
#[inline]
/// Check last request's message type
pub fn message_type(&self) -> MessageType {
if self.flags.contains(Flags::STREAM) {
MessageType::Stream
} else if self.payload.is_none() {
MessageType::None
} else {
MessageType::Payload
}
self.flags.get().contains(Flags::KEEPALIVE_ENABLED)
}
#[inline]
@ -106,41 +103,36 @@ impl Codec {
pub fn set_date_header(&self, dst: &mut BytesMut) {
self.timer.set_date_header(dst)
}
fn insert_flags(&self, f: Flags) {
let mut flags = self.flags.get();
flags.insert(f);
self.flags.set(flags);
}
}
impl Decoder for Codec {
type Item = Message<Request>;
type Item = (Request, PayloadType);
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if let Some(ref mut payload) = self.payload {
Ok(match payload.decode(src)? {
Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
Some(PayloadItem::Eof) => {
self.payload.take();
Some(Message::Chunk(None))
}
None => None,
})
} else if let Some((req, payload)) = self.decoder.decode(src)? {
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if let Some((req, payload)) = self.decoder.decode(src)? {
let head = req.head();
self.flags.set(Flags::HEAD, head.method == Method::HEAD);
self.version = head.version;
self.ctype = head.connection_type();
if self.ctype == ConnectionType::KeepAlive
&& !self.flags.contains(Flags::KEEPALIVE_ENABLED)
let mut flags = self.flags.get();
flags.set(Flags::HEAD, head.method == Method::HEAD);
self.flags.set(flags);
self.version.set(head.version);
self.ctype.set(head.connection_type());
if self.ctype.get() == ConnectionType::KeepAlive
&& !flags.contains(Flags::KEEPALIVE_ENABLED)
{
self.ctype = ConnectionType::Close
self.ctype.set(ConnectionType::Close)
}
match payload {
PayloadType::None => self.payload = None,
PayloadType::Payload(pl) => self.payload = Some(pl),
PayloadType::Stream(pl) => {
self.payload = Some(pl);
self.flags.insert(Flags::STREAM);
}
if let PayloadType::Stream(_) = payload {
self.insert_flags(Flags::STREAM)
}
Ok(Some(Message::Item(req)))
Ok(Some((req, payload)))
} else {
Ok(None)
}
@ -151,36 +143,28 @@ impl Encoder for Codec {
type Item = Message<(Response<()>, BodySize)>;
type Error = io::Error;
fn encode(
&mut self,
item: Self::Item,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
fn encode(&self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Item((mut res, length)) => {
// set response version
res.head_mut().version = self.version;
res.head_mut().version = self.version.get();
// connection status
self.ctype = if let Some(ct) = res.head().ctype() {
if ct == ConnectionType::KeepAlive {
self.ctype
} else {
ct
if let Some(ct) = res.head().ctype() {
if ct != ConnectionType::KeepAlive {
self.ctype.set(ct)
}
} else {
self.ctype
};
}
// encode message
self.encoder.encode(
dst,
&mut res,
self.flags.contains(Flags::HEAD),
self.flags.contains(Flags::STREAM),
self.version,
self.flags.get().contains(Flags::HEAD),
self.flags.get().contains(Flags::STREAM),
self.version.get(),
length,
self.ctype,
self.ctype.get(),
&self.timer,
)?;
// self.headers_size = (dst.len() - len) as u32;
@ -198,22 +182,25 @@ impl Encoder for Codec {
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use bytes::{Bytes, BytesMut};
use super::*;
use crate::http::{HttpMessage, Method};
use crate::http::{h1::PayloadItem, HttpMessage, Method};
#[test]
fn test_http_request_chunked_payload_and_next_message() {
let mut codec = Codec::default();
let codec = Codec::default();
assert!(format!("{:?}", codec).contains("h1::Codec"));
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let item = codec.decode(&mut buf).unwrap().unwrap();
let req = item.message();
let (req, pl) = codec.decode(&mut buf).unwrap().unwrap();
let pl = match pl {
PayloadType::Payload(pl) => pl,
_ => panic!(),
};
assert_eq!(req.method(), Method::GET);
assert!(req.chunked().unwrap());
@ -225,22 +212,21 @@ mod tests {
.iter(),
);
let msg = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"data");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg, PayloadItem::Chunk(Bytes::from_static(b"data")));
let msg = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg.chunk().as_ref(), b"line");
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg, PayloadItem::Chunk(Bytes::from_static(b"line")));
let msg = codec.decode(&mut buf).unwrap().unwrap();
assert!(msg.eof());
let msg = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(msg, PayloadItem::Eof);
// decode next message
let item = codec.decode(&mut buf).unwrap().unwrap();
let req = item.message();
let (req, _pl) = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(*req.method(), Method::POST);
assert!(req.chunked().unwrap());
let mut codec = Codec::default();
let codec = Codec::default();
let mut buf = BytesMut::from(
"GET /test HTTP/1.1\r\n\
connection: upgrade\r\n\r\n",

View file

@ -1,12 +1,10 @@
use std::convert::TryFrom;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::task::Poll;
use std::{
cell::Cell, convert::TryFrom, marker::PhantomData, mem::MaybeUninit, task::Poll,
};
use bytes::{Buf, Bytes, BytesMut};
use http::header::{HeaderName, HeaderValue};
use http::{header, Method, StatusCode, Uri, Version};
use log::{debug, error, trace};
use crate::codec::Decoder;
use crate::http::error::ParseError;
@ -23,7 +21,7 @@ pub(super) struct MessageDecoder<T: MessageType>(PhantomData<T>);
#[derive(Debug)]
/// Incoming request type
pub(super) enum PayloadType {
pub enum PayloadType {
None,
Payload(PayloadDecoder),
Stream(PayloadDecoder),
@ -35,11 +33,17 @@ impl<T: MessageType> Default for MessageDecoder<T> {
}
}
impl<T: MessageType> Clone for MessageDecoder<T> {
fn clone(&self) -> Self {
MessageDecoder(PhantomData)
}
}
impl<T: MessageType> Decoder for MessageDecoder<T> {
type Item = (T, PayloadType);
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
T::decode(src)
}
}
@ -91,11 +95,11 @@ pub(super) trait MessageType: Sized {
content_length = Some(len);
}
} else {
debug!("illegal Content-Length: {:?}", s);
log::debug!("illegal Content-Length: {:?}", s);
return Err(ParseError::Header);
}
} else {
debug!("illegal Content-Length: {:?}", value);
log::debug!("illegal Content-Length: {:?}", value);
return Err(ParseError::Header);
}
}
@ -290,7 +294,7 @@ impl MessageType for ResponseHead {
}
httparse::Status::Partial => {
return if src.len() >= MAX_BUFFER_SIZE {
error!("MAX_BUFFER_SIZE unprocessed data reached, closing");
log::error!("MAX_BUFFER_SIZE unprocessed data reached, closing");
Err(ParseError::TooLarge)
} else {
Ok(None)
@ -351,7 +355,7 @@ impl HeaderIndex {
#[derive(Debug, Clone, PartialEq)]
/// Http payload item
pub(super) enum PayloadItem {
pub enum PayloadItem {
Chunk(Bytes),
Eof,
}
@ -361,29 +365,31 @@ pub(super) enum PayloadItem {
/// If a message body does not include a Transfer-Encoding, it *should*
/// include a Content-Length header.
#[derive(Debug, Clone, PartialEq)]
pub(super) struct PayloadDecoder {
kind: Kind,
pub struct PayloadDecoder {
kind: Cell<Kind>,
}
impl PayloadDecoder {
pub(super) fn length(x: u64) -> PayloadDecoder {
PayloadDecoder {
kind: Kind::Length(x),
kind: Cell::new(Kind::Length(x)),
}
}
pub(super) fn chunked() -> PayloadDecoder {
PayloadDecoder {
kind: Kind::Chunked(ChunkedState::Size, 0),
kind: Cell::new(Kind::Chunked(ChunkedState::Size, 0)),
}
}
pub(super) fn eof() -> PayloadDecoder {
PayloadDecoder { kind: Kind::Eof }
PayloadDecoder {
kind: Cell::new(Kind::Eof),
}
}
}
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Copy, Clone, PartialEq)]
enum Kind {
/// A Reader used when a Content-Length header is passed with a positive
/// integer.
@ -407,7 +413,7 @@ enum Kind {
Eof,
}
#[derive(Debug, PartialEq, Clone)]
#[derive(Debug, PartialEq, Copy, Clone)]
enum ChunkedState {
Size,
SizeLws,
@ -425,8 +431,10 @@ impl Decoder for PayloadDecoder {
type Item = PayloadItem;
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.kind {
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut kind = self.kind.get();
match kind {
Kind::Length(ref mut remaining) => {
if *remaining == 0 {
Ok(Some(PayloadItem::Eof))
@ -443,30 +451,35 @@ impl Decoder for PayloadDecoder {
buf = src.split_to(*remaining as usize).freeze();
*remaining = 0;
};
trace!("Length read: {}", buf.len());
self.kind.set(kind);
log::trace!("Length read: {}", buf.len());
Ok(Some(PayloadItem::Chunk(buf)))
}
}
Kind::Chunked(ref mut state, ref mut size) => {
loop {
let result = loop {
let mut buf = None;
// advances the chunked state
*state = match state.step(src, size, &mut buf) {
Poll::Pending => return Ok(None),
Poll::Pending => break Ok(None),
Poll::Ready(Ok(state)) => state,
Poll::Ready(Err(e)) => return Err(e),
Poll::Ready(Err(e)) => break Err(e),
};
if *state == ChunkedState::End {
trace!("End of chunked stream");
return Ok(Some(PayloadItem::Eof));
log::trace!("End of chunked stream");
break Ok(Some(PayloadItem::Eof));
}
if let Some(buf) = buf {
return Ok(Some(PayloadItem::Chunk(buf)));
break Ok(Some(PayloadItem::Chunk(buf)));
}
if src.is_empty() {
return Ok(None);
break Ok(None);
}
}
};
self.kind.set(kind);
result
}
Kind::Eof => {
if src.is_empty() {
@ -544,7 +557,7 @@ impl ChunkedState {
}
fn read_size_lws(rdr: &mut BytesMut) -> Poll<Result<ChunkedState, ParseError>> {
trace!("read_size_lws");
log::trace!("read_size_lws");
match byte!(rdr) {
// LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
@ -577,7 +590,7 @@ impl ChunkedState {
rem: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, ParseError>> {
trace!("Chunked read, remaining={:?}", rem);
log::trace!("Chunked read, remaining={:?}", rem);
let len = rdr.len() as u64;
if len == 0 {
@ -693,7 +706,7 @@ mod tests {
fn test_parse() {
let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n\r\n");
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
match reader.decode(&mut buf) {
Ok(Some((req, _))) => {
assert_eq!(req.version(), Version::HTTP_11);
@ -708,7 +721,7 @@ mod tests {
fn test_parse_partial() {
let mut buf = BytesMut::from("PUT /test HTTP/1");
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
assert!(reader.decode(&mut buf).unwrap().is_none());
buf.extend(b".1\r\n\r\n");
@ -722,7 +735,7 @@ mod tests {
fn test_parse_post() {
let mut buf = BytesMut::from("POST /test2 HTTP/1.0\r\n\r\n");
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, _) = reader.decode(&mut buf).unwrap().unwrap();
assert_eq!(req.version(), Version::HTTP_10);
assert_eq!(*req.method(), Method::POST);
@ -734,9 +747,9 @@ mod tests {
let mut buf =
BytesMut::from("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let pl = pl.unwrap();
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@ -751,9 +764,9 @@ mod tests {
let mut buf =
BytesMut::from("\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let pl = pl.unwrap();
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@ -766,7 +779,7 @@ mod tests {
#[test]
fn test_parse_partial_eof() {
let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n");
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
assert!(reader.decode(&mut buf).unwrap().is_none());
buf.extend(b"\r\n");
@ -780,7 +793,7 @@ mod tests {
fn test_headers_split_field() {
let mut buf = BytesMut::from("GET /test HTTP/1.1\r\n");
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
assert! { reader.decode(&mut buf).unwrap().is_none() }
buf.extend(b"t");
@ -810,7 +823,7 @@ mod tests {
Set-Cookie: c1=cookie1\r\n\
Set-Cookie: c2=cookie2\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, _) = reader.decode(&mut buf).unwrap().unwrap();
let val: Vec<_> = req
@ -1037,7 +1050,7 @@ mod tests {
upgrade: websocket\r\n\r\n\
some raw data",
);
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
assert_eq!(req.head().connection_type(), ConnectionType::Upgrade);
assert!(req.upgrade());
@ -1086,9 +1099,9 @@ mod tests {
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(b"4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n");
@ -1109,9 +1122,9 @@ mod tests {
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(
@ -1140,9 +1153,9 @@ mod tests {
transfer-encoding: chunked\r\n\r\n",
);
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let pl = pl.unwrap();
assert!(req.chunked().unwrap());
buf.extend(b"4\r\n1111\r\n");
@ -1185,9 +1198,9 @@ mod tests {
transfer-encoding: chunked\r\n\r\n"[..],
);
let mut reader = MessageDecoder::<Request>::default();
let reader = MessageDecoder::<Request>::default();
let (msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let pl = pl.unwrap();
assert!(msg.chunked().unwrap());
buf.extend(b"4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n")
@ -1203,9 +1216,9 @@ mod tests {
fn test_response_http10_read_until_eof() {
let mut buf = BytesMut::from(&"HTTP/1.0 200 Ok\r\n\r\ntest data"[..]);
let mut reader = MessageDecoder::<ResponseHead>::default();
let reader = MessageDecoder::<ResponseHead>::default();
let (_msg, pl) = reader.decode(&mut buf).unwrap().unwrap();
let mut pl = pl.unwrap();
let pl = pl.unwrap();
let chunk = pl.decode(&mut buf).unwrap().unwrap();
assert_eq!(chunk, PayloadItem::Chunk(Bytes::from_static(b"test data")));

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,6 @@
use std::io::Write;
use std::marker::PhantomData;
use std::ptr::copy_nonoverlapping;
use std::{cmp, io, mem, ptr, slice};
use std::{cell::Cell, cmp, io, mem, ptr, ptr::copy_nonoverlapping, slice};
use bytes::{BufMut, BytesMut};
@ -18,7 +17,7 @@ const AVERAGE_HEADER_SIZE: usize = 30;
#[derive(Debug)]
pub(super) struct MessageEncoder<T: MessageType> {
pub(super) length: BodySize,
pub(super) te: TransferEncoding,
pub(super) te: Cell<TransferEncoding>,
_t: PhantomData<T>,
}
@ -26,7 +25,17 @@ impl<T: MessageType> Default for MessageEncoder<T> {
fn default() -> Self {
MessageEncoder {
length: BodySize::None,
te: TransferEncoding::empty(),
te: Cell::new(TransferEncoding::empty()),
_t: PhantomData,
}
}
}
impl<T: MessageType> Clone for MessageEncoder<T> {
fn clone(&self) -> Self {
MessageEncoder {
length: self.length,
te: self.te.clone(),
_t: PhantomData,
}
}
@ -41,10 +50,10 @@ pub(super) trait MessageType: Sized {
fn chunked(&self) -> bool;
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()>;
fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()>;
fn encode_headers(
&mut self,
&self,
dst: &mut BytesMut,
version: Version,
mut length: BodySize,
@ -208,7 +217,7 @@ impl MessageType for Response<()> {
None
}
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> {
let head = self.head();
let reason = head.reason().as_bytes();
dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
@ -237,7 +246,7 @@ impl MessageType for RequestHeadType {
self.extra_headers()
}
fn encode_status(&mut self, dst: &mut BytesMut) -> io::Result<()> {
fn encode_status(&self, dst: &mut BytesMut) -> io::Result<()> {
let head = self.as_ref();
dst.reserve(256 + head.headers.len() * AVERAGE_HEADER_SIZE);
write!(
@ -264,20 +273,26 @@ impl MessageType for RequestHeadType {
impl<T: MessageType> MessageEncoder<T> {
/// Encode message
pub(super) fn encode_chunk(
&mut self,
&self,
msg: &[u8],
buf: &mut BytesMut,
) -> io::Result<bool> {
self.te.encode(msg, buf)
let mut te = self.te.get();
let result = te.encode(msg, buf);
self.te.set(te);
result
}
/// Encode eof
pub(super) fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
self.te.encode_eof(buf)
pub(super) fn encode_eof(&self, buf: &mut BytesMut) -> io::Result<()> {
let mut te = self.te.get();
let result = te.encode_eof(buf);
self.te.set(te);
result
}
pub(super) fn encode(
&mut self,
&self,
dst: &mut BytesMut,
message: &mut T,
head: bool,
@ -289,7 +304,7 @@ impl<T: MessageType> MessageEncoder<T> {
) -> io::Result<()> {
// transfer encoding
if !head {
self.te = match length {
self.te.set(match length {
BodySize::Empty => TransferEncoding::empty(),
BodySize::Sized(len) => TransferEncoding::length(len),
BodySize::Stream => {
@ -300,9 +315,9 @@ impl<T: MessageType> MessageEncoder<T> {
}
}
BodySize::None => TransferEncoding::empty(),
};
});
} else {
self.te = TransferEncoding::empty();
self.te.set(TransferEncoding::empty());
}
message.encode_status(dst)?;
@ -311,12 +326,12 @@ impl<T: MessageType> MessageEncoder<T> {
}
/// Encoders to handle different Transfer-Encodings.
#[derive(Debug)]
#[derive(Debug, Copy, Clone)]
pub(super) struct TransferEncoding {
kind: TransferEncodingKind,
}
#[derive(Debug, PartialEq, Clone)]
#[derive(Debug, PartialEq, Clone, Copy)]
enum TransferEncodingKind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
Chunked(bool),
@ -368,14 +383,15 @@ impl TransferEncoding {
buf.extend_from_slice(msg);
Ok(eof)
}
TransferEncodingKind::Chunked(ref mut eof) => {
if *eof {
TransferEncodingKind::Chunked(eof) => {
if eof {
return Ok(true);
}
if msg.is_empty() {
*eof = true;
let result = if msg.is_empty() {
buf.extend_from_slice(b"0\r\n\r\n");
self.kind = TransferEncodingKind::Chunked(true);
true
} else {
writeln!(helpers::Writer(buf), "{:X}\r", msg.len())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
@ -383,20 +399,22 @@ impl TransferEncoding {
buf.reserve(msg.len() + 2);
buf.extend_from_slice(msg);
buf.extend_from_slice(b"\r\n");
}
Ok(*eof)
false
};
Ok(result)
}
TransferEncodingKind::Length(ref mut remaining) => {
if *remaining > 0 {
TransferEncodingKind::Length(mut remaining) => {
if remaining > 0 {
if msg.is_empty() {
return Ok(*remaining == 0);
return Ok(remaining == 0);
}
let len = cmp::min(*remaining, msg.len() as u64);
let len = cmp::min(remaining, msg.len() as u64);
buf.extend_from_slice(&msg[..len as usize]);
*remaining -= len as u64;
Ok(*remaining == 0)
remaining -= len as u64;
self.kind = TransferEncodingKind::Length(remaining);
Ok(remaining == 0)
} else {
Ok(true)
}
@ -416,10 +434,10 @@ impl TransferEncoding {
Ok(())
}
}
TransferEncodingKind::Chunked(ref mut eof) => {
if !*eof {
*eof = true;
TransferEncodingKind::Chunked(eof) => {
if !eof {
buf.extend_from_slice(b"0\r\n\r\n");
self.kind = TransferEncodingKind::Chunked(true);
}
Ok(())
}
@ -614,7 +632,7 @@ mod tests {
);
extra_headers.insert(DATE, HeaderValue::from_static("date"));
let mut head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers));
let head = RequestHeadType::Rc(Rc::new(head), Some(extra_headers));
let _ = head.encode_headers(
&mut bytes,

View file

@ -13,6 +13,7 @@ mod upgrade;
pub use self::client::{ClientCodec, ClientPayloadCodec};
pub use self::codec::Codec;
pub use self::decoder::{PayloadDecoder, PayloadItem, PayloadType};
pub use self::expect::ExpectHandler;
pub use self::payload::Payload;
pub use self::service::{H1Service, H1ServiceHandler};
@ -54,33 +55,3 @@ pub(crate) fn reserve_readbuf(src: &mut BytesMut) {
src.reserve(HW - cap);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::http::request::Request;
impl Message<Request> {
pub fn message(self) -> Request {
match self {
Message::Item(req) => req,
_ => panic!("error"),
}
}
pub fn chunk(self) -> Bytes {
match self {
Message::Chunk(Some(data)) => data,
_ => panic!("error"),
}
}
pub fn eof(self) -> bool {
match self {
Message::Chunk(None) => true,
Message::Chunk(Some(_)) => false,
_ => panic!("error"),
}
}
}
}

View file

@ -119,8 +119,8 @@ impl PayloadSender {
}
}
pub(super) fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus {
// we check need_read only if Payload (other side) is alive,
pub(super) fn poll_data_required(&self, cx: &mut Context<'_>) -> PayloadStatus {
// we check only if Payload (other side) is alive,
// otherwise always return true (consume payload)
if let Some(shared) = self.inner.upgrade() {
if shared.borrow().need_read {

View file

@ -1,11 +1,11 @@
use std::marker::PhantomData;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, net};
use std::{
error::Error, fmt, marker::PhantomData, net, rc::Rc, task::Context, task::Poll,
};
use futures::future::{ok, FutureExt, LocalBoxFuture};
use crate::codec::{AsyncRead, AsyncWrite, Framed};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::State as IoState;
use crate::http::body::MessageBody;
use crate::http::config::{DispatcherConfig, ServiceConfig};
use crate::http::error::{DispatchError, ResponseError};
@ -34,7 +34,7 @@ pub struct H1Service<T, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
impl<T, S, B> H1Service<T, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
B: MessageBody,
@ -59,21 +59,17 @@ where
impl<S, B, X, U> H1Service<TcpStream, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Future: 'static,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<TcpStream, Codec>),
Response = (),
>,
U::Error: fmt::Display + ResponseError,
U: ServiceFactory<Config = (), Request = (Request, IoState, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
{
@ -105,21 +101,21 @@ mod openssl {
impl<S, B, X, U> H1Service<SslStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Future: 'static,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<SslStream<TcpStream>, Codec>),
Request = (Request, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + ResponseError,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
{
@ -136,7 +132,7 @@ mod openssl {
> {
pipeline_factory(
Acceptor::new(acceptor)
.timeout(self.handshake_timeout)
.timeout((self.handshake_timeout as u64) * 1000)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
@ -159,21 +155,21 @@ mod rustls {
impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Future: 'static,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<TlsStream<TcpStream>, Codec>),
Request = (Request, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + ResponseError,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
{
@ -190,7 +186,7 @@ mod rustls {
> {
pipeline_factory(
Acceptor::new(config)
.timeout(self.handshake_timeout)
.timeout((self.handshake_timeout as u64) * 1000)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
@ -206,7 +202,7 @@ mod rustls {
impl<T, S, B, X, U> H1Service<T, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
S::InitError: fmt::Debug,
S::Future: 'static,
@ -215,7 +211,7 @@ where
pub fn expect<X1>(self, expect: X1) -> H1Service<T, S, B, X1, U>
where
X1: ServiceFactory<Request = Request, Response = Request>,
X1::Error: ResponseError,
X1::Error: ResponseError + 'static,
X1::InitError: fmt::Debug,
X1::Future: 'static,
{
@ -232,8 +228,8 @@ where
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, S, B, X, U1>
where
U1: ServiceFactory<Request = (Request, Framed<T, Codec>), Response = ()>,
U1::Error: fmt::Display,
U1: ServiceFactory<Request = (Request, IoState, Codec), Response = ()>,
U1::Error: fmt::Display + Error + 'static,
U1::InitError: fmt::Debug,
U1::Future: 'static,
{
@ -262,17 +258,17 @@ impl<T, S, B, X, U> ServiceFactory for H1Service<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
S::InitError: fmt::Debug,
S::Future: 'static,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display + ResponseError,
U: ServiceFactory<Config = (), Request = (Request, IoState, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
{
@ -328,20 +324,20 @@ pub struct H1ServiceHandler<T, S: Service, B, X: Service, U: Service> {
impl<T, S, B, X, U> Service for H1ServiceHandler<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display + ResponseError,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, IoState, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
{
type Request = (T, Option<net::SocketAddr>);
type Response = ();
type Error = DispatchError;
type Future = Dispatcher<T, S, B, X, U>;
type Future = Dispatcher<S, B, X, U>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let cfg = self.config.as_ref();
@ -369,7 +365,7 @@ where
upg.poll_ready(cx)
.map_err(|e| {
log::error!("Http service readiness error: {:?}", e);
DispatchError::Service(Box::new(e))
DispatchError::Upgrade(Box::new(e))
})?
.is_ready()
&& ready
@ -407,6 +403,6 @@ where
None
};
Dispatcher::new(self.config.clone(), io, addr, on_connect)
Dispatcher::new(io, self.config.clone(), addr, on_connect)
}
}

View file

@ -1,10 +1,8 @@
use std::io;
use std::marker::PhantomData;
use std::task::{Context, Poll};
use std::{io, marker::PhantomData, task::Context, task::Poll};
use futures::future::Ready;
use crate::codec::Framed;
use crate::framed::State;
use crate::http::h1::Codec;
use crate::http::request::Request;
use crate::{Service, ServiceFactory};
@ -13,7 +11,7 @@ pub struct UpgradeHandler<T>(PhantomData<T>);
impl<T> ServiceFactory for UpgradeHandler<T> {
type Config = ();
type Request = (Request, Framed<T, Codec>);
type Request = (Request, State, Codec);
type Response = ();
type Error = io::Error;
type Service = UpgradeHandler<T>;
@ -27,7 +25,7 @@ impl<T> ServiceFactory for UpgradeHandler<T> {
}
impl<T> Service for UpgradeHandler<T> {
type Request = (Request, Framed<T, Codec>);
type Request = (Request, State, Codec);
type Response = ();
type Error = io::Error;
type Future = Ready<Result<Self::Response, Self::Error>>;

View file

@ -40,7 +40,7 @@ impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
{
@ -76,7 +76,7 @@ impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
@ -100,9 +100,7 @@ where
}
let (parts, body) = req.into_parts();
let mut req = Request::with_payload(Payload::<
crate::http::payload::PayloadStream,
>::H2(
let mut req = Request::with_payload(Payload::H2(
crate::http::h2::Payload::new(body),
));
@ -155,7 +153,7 @@ pin_project_lite::pin_project! {
impl<F, I, E, B> ServiceResponse<F, I, E, B>
where
F: Future<Output = Result<I, E>>,
E: ResponseError,
E: ResponseError + 'static,
I: Into<Response<B>>,
B: MessageBody,
{
@ -221,7 +219,7 @@ where
impl<F, I, E, B> Future for ServiceResponse<F, I, E, B>
where
F: Future<Output = Result<I, E>>,
E: ResponseError,
E: ResponseError + 'static,
I: Into<Response<B>>,
B: MessageBody,
{
@ -260,7 +258,7 @@ where
}
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
let res: Response = e.into();
let res: Response = (&e).into();
let (res, body) = res.replace_body(());
let mut send = send.take().unwrap();

View file

@ -38,7 +38,7 @@ pub struct H2Service<T, S, B> {
impl<T, S, B> H2Service<T, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
@ -52,7 +52,7 @@ where
H2Service {
on_connect: None,
srv: service.into_factory(),
handshake_timeout: cfg.0.ssl_handshake_timeout,
handshake_timeout: (cfg.0.ssl_handshake_timeout as u64) * 1000,
_t: PhantomData,
cfg,
}
@ -71,7 +71,7 @@ where
impl<S, B> H2Service<TcpStream, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
@ -108,7 +108,7 @@ mod openssl {
impl<S, B> H2Service<SslStream<TcpStream>, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
@ -151,7 +151,7 @@ mod rustls {
impl<S, B> H2Service<TlsStream<TcpStream>, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
@ -192,7 +192,7 @@ impl<T, S, B> ServiceFactory for H2Service<T, S, B>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
@ -236,7 +236,7 @@ impl<T, S, B> Service for H2ServiceHandler<T, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
@ -295,7 +295,7 @@ pub struct H2ServiceHandlerResponse<T, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
@ -307,7 +307,7 @@ impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody,

View file

@ -45,7 +45,3 @@ pub enum Protocol {
Http1,
Http2,
}
#[doc(hidden)]
#[deprecated(since = "0.1.19", note = "Use ntex::util::Extensions instead")]
pub use crate::util::Extensions;

View file

@ -1,45 +1,44 @@
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem, pin::Pin, task::Context, task::Poll};
use bytes::Bytes;
use futures::Stream;
use h2::RecvStream;
use super::error::PayloadError;
use super::{h1, h2 as h2d};
/// Type represent boxed payload
pub type PayloadStream = Pin<Box<dyn Stream<Item = Result<Bytes, PayloadError>>>>;
/// Type represent streaming payload
pub enum Payload<S = PayloadStream> {
pub enum Payload {
None,
H1(crate::http::h1::Payload),
H2(crate::http::h2::Payload),
Stream(S),
H1(h1::Payload),
H2(h2d::Payload),
Stream(PayloadStream),
}
impl<S> Default for Payload<S> {
impl Default for Payload {
fn default() -> Self {
Payload::None
}
}
impl<S> From<crate::http::h1::Payload> for Payload<S> {
fn from(v: crate::http::h1::Payload) -> Self {
impl From<h1::Payload> for Payload {
fn from(v: h1::Payload) -> Self {
Payload::H1(v)
}
}
impl<S> From<crate::http::h2::Payload> for Payload<S> {
fn from(v: crate::http::h2::Payload) -> Self {
impl From<h2d::Payload> for Payload {
fn from(v: h2d::Payload) -> Self {
Payload::H2(v)
}
}
impl<S> From<RecvStream> for Payload<S> {
impl From<RecvStream> for Payload {
fn from(v: RecvStream) -> Self {
Payload::H2(crate::http::h2::Payload::new(v))
Payload::H2(h2d::Payload::new(v))
}
}
@ -49,7 +48,7 @@ impl From<PayloadStream> for Payload {
}
}
impl<S> fmt::Debug for Payload<S> {
impl fmt::Debug for Payload {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Payload::None => write!(f, "Payload::None"),
@ -60,17 +59,22 @@ impl<S> fmt::Debug for Payload<S> {
}
}
impl<S> Payload<S> {
impl Payload {
/// Takes current payload and replaces it with `None` value
pub fn take(&mut self) -> Payload<S> {
std::mem::take(self)
pub fn take(&mut self) -> Self {
mem::take(self)
}
/// Create payload from stream
pub fn from_stream<S>(stream: S) -> Self
where
S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
Payload::Stream(Box::pin(stream))
}
}
impl<S> Stream for Payload<S>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin,
{
impl Stream for Payload {
type Item = Result<Bytes, PayloadError>;
#[inline]
@ -93,19 +97,12 @@ mod tests {
#[test]
fn payload_debug() {
assert!(
format!("{:?}", Payload::<PayloadStream>::None).contains("Payload::None")
);
assert!(format!("{:?}", Payload::None).contains("Payload::None"));
assert!(format!("{:?}", Payload::H1(h1::Payload::create(false).1))
.contains("Payload::H1"));
assert!(format!(
"{:?}",
Payload::<PayloadStream>::H1(crate::http::h1::Payload::create(false).1)
)
.contains("Payload::H1"));
assert!(format!(
"{:?}",
Payload::<PayloadStream>::Stream(Box::pin(
crate::http::h1::Payload::create(false).1
))
Payload::Stream(Box::pin(h1::Payload::create(false).1))
)
.contains("Payload::Stream"));
}

View file

@ -1,21 +1,20 @@
use std::cell::{Ref, RefMut};
use std::{fmt, net};
use std::{cell::Ref, cell::RefMut, fmt, mem, net};
use http::{header, Method, Uri, Version};
use crate::http::header::HeaderMap;
use crate::http::httpmessage::HttpMessage;
use crate::http::message::{Message, RequestHead};
use crate::http::payload::{Payload, PayloadStream};
use crate::http::payload::Payload;
use crate::util::Extensions;
/// Request
pub struct Request<P = PayloadStream> {
pub(crate) payload: Payload<P>,
pub struct Request {
pub(crate) payload: Payload,
pub(crate) head: Message<RequestHead>,
}
impl<P> HttpMessage for Request<P> {
impl HttpMessage for Request {
#[inline]
fn message_headers(&self) -> &HeaderMap {
&self.head().headers
@ -34,7 +33,7 @@ impl<P> HttpMessage for Request<P> {
}
}
impl From<Message<RequestHead>> for Request<PayloadStream> {
impl From<Message<RequestHead>> for Request {
fn from(head: Message<RequestHead>) -> Self {
Request {
head,
@ -43,9 +42,9 @@ impl From<Message<RequestHead>> for Request<PayloadStream> {
}
}
impl Request<PayloadStream> {
impl Request {
/// Create new Request instance
pub fn new() -> Request<PayloadStream> {
pub fn new() -> Request {
Request {
head: Message::new(),
payload: Payload::None,
@ -53,9 +52,9 @@ impl Request<PayloadStream> {
}
}
impl<P> Request<P> {
impl Request {
/// Create new Request instance
pub fn with_payload(payload: Payload<P>) -> Request<P> {
pub fn with_payload(payload: Payload) -> Request {
Request {
payload,
head: Message::new(),
@ -137,25 +136,18 @@ impl<P> Request<P> {
}
/// Get request's payload
pub fn payload(&mut self) -> &mut Payload<P> {
pub fn payload(&mut self) -> &mut Payload {
&mut self.payload
}
/// Get request's payload
pub fn take_payload(&mut self) -> Payload<P> {
std::mem::take(&mut self.payload)
pub fn take_payload(&mut self) -> Payload {
mem::take(&mut self.payload)
}
/// Create new Request instance
pub fn replace_payload<P1>(self, payload: Payload<P1>) -> (Request<P1>, Payload<P>) {
let pl = self.payload;
(
Request {
payload,
head: self.head,
},
pl,
)
/// Replace request's payload, returns old one
pub fn replace_payload(&mut self, payload: Payload) -> Payload {
mem::replace(&mut self.payload, payload)
}
/// Request extensions
@ -172,12 +164,12 @@ impl<P> Request<P> {
#[allow(dead_code)]
/// Split request into request head and payload
pub(crate) fn into_parts(self) -> (Message<RequestHead>, Payload<P>) {
pub(crate) fn into_parts(self) -> (Message<RequestHead>, Payload) {
(self.head, self.payload)
}
}
impl<P> fmt::Debug for Request<P> {
impl fmt::Debug for Request {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,

View file

@ -610,7 +610,7 @@ impl ResponseBuilder {
self.body(Body::from(body))
}
Err(e) => e.into(),
Err(e) => (&e).into(),
}
}
@ -755,7 +755,7 @@ where
fn from(res: Result<I, E>) -> Self {
match res {
Ok(val) => val.into(),
Err(err) => err.into(),
Err(err) => (&err).into(),
}
}
}

View file

@ -1,10 +1,13 @@
use std::{fmt, marker::PhantomData, net, pin::Pin, rc::Rc, task::Context, task::Poll};
use std::{
error, fmt, marker::PhantomData, net, pin::Pin, rc::Rc, task::Context, task::Poll,
};
use bytes::Bytes;
use futures::future::{ok, Future, FutureExt, LocalBoxFuture};
use h2::server::{self, Handshake};
use crate::codec::{AsyncRead, AsyncWrite, Framed};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::State;
use crate::rt::net::TcpStream;
use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
@ -30,7 +33,7 @@ pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>
impl<T, S, B> HttpService<T, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
@ -46,7 +49,7 @@ where
impl<T, S, B> HttpService<T, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
@ -86,7 +89,7 @@ where
impl<T, S, B, X, U> HttpService<T, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
@ -123,10 +126,10 @@ where
where
U1: ServiceFactory<
Config = (),
Request = (Request, Framed<T, h1::Codec>),
Request = (Request, State, h1::Codec),
Response = (),
>,
U1::Error: fmt::Display,
U1::Error: fmt::Display + error::Error + 'static,
U1::InitError: fmt::Debug,
U1::Future: 'static,
{
@ -153,23 +156,19 @@ where
impl<S, B, X, U> HttpService<TcpStream, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<TcpStream, h1::Codec>),
Response = (),
>,
U::Error: fmt::Display + ResponseError,
U: ServiceFactory<Config = (), Request = (Request, State, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
<U::Service as Service>::Future: 'static,
@ -201,23 +200,23 @@ mod openssl {
impl<S, B, X, U> HttpService<SslStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<SslStream<TcpStream>, h1::Codec>),
Request = (Request, State, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + ResponseError,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
<U::Service as Service>::Future: 'static,
@ -235,7 +234,7 @@ mod openssl {
> {
pipeline_factory(
Acceptor::new(acceptor)
.timeout(self.cfg.0.ssl_handshake_timeout)
.timeout((self.cfg.0.ssl_handshake_timeout as u64) * 1000)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
@ -266,23 +265,23 @@ mod rustls {
impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<TlsStream<TcpStream>, h1::Codec>),
Request = (Request, State, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + ResponseError,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
<U::Service as Service>::Future: 'static,
@ -303,7 +302,7 @@ mod rustls {
pipeline_factory(
Acceptor::new(config)
.timeout(self.cfg.0.ssl_handshake_timeout)
.timeout((self.cfg.0.ssl_handshake_timeout as u64) * 1000)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
@ -332,23 +331,19 @@ impl<T, S, B, X, U> ServiceFactory for HttpService<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, Framed<T, h1::Codec>),
Response = (),
>,
U::Error: fmt::Display + ResponseError,
U: ServiceFactory<Config = (), Request = (Request, State, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
<U::Service as Service>::Future: 'static,
@ -407,16 +402,16 @@ pub struct HttpServiceHandler<T, S: Service, B, X: Service, U: Service> {
impl<T, S, B, X, U> Service for HttpServiceHandler<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + 'static,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display + ResponseError,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
{
type Request = (T, Protocol, Option<net::SocketAddr>);
type Response = ();
@ -449,7 +444,7 @@ where
upg.poll_ready(cx)
.map_err(|e| {
log::error!("Http service readiness error: {:?}", e);
DispatchError::Service(Box::new(e))
DispatchError::Upgrade(Box::new(e))
})?
.is_ready()
&& ready
@ -489,7 +484,7 @@ where
match proto {
Protocol::Http2 => HttpServiceHandlerResponse {
state: State::H2Handshake {
state: ResponseState::H2Handshake {
data: Some((
server::handshake(io),
self.config.clone(),
@ -499,10 +494,10 @@ where
},
},
Protocol::Http1 => HttpServiceHandlerResponse {
state: State::H1 {
state: ResponseState::H1 {
fut: h1::Dispatcher::new(
self.config.clone(),
io,
self.config.clone(),
peer_addr,
on_connect,
),
@ -520,36 +515,44 @@ pin_project_lite::pin_project! {
T: Unpin,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: 'static,
S::Response: Into<Response<B>>,
S::Response: 'static,
B: MessageBody,
B: 'static,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
X::Error: 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U::Error: fmt::Display,
U::Error: error::Error,
U::Error: 'static,
{
#[pin]
state: State<T, S, B, X, U>,
state: ResponseState<T, S, B, X, U>,
}
}
pin_project_lite::pin_project! {
#[project = StateProject]
enum State<T, S, B, X, U>
enum ResponseState<T, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: 'static,
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
X::Error: 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U::Error: fmt::Display,
U::Error: error::Error,
U::Error: 'static,
{
H1 { #[pin] fut: h1::Dispatcher<T, S, B, X, U> },
H1 { #[pin] fut: h1::Dispatcher<S, B, X, U> },
H2 { fut: Dispatcher<T, S, B, X, U> },
H2Handshake { data:
Option<(
@ -566,14 +569,14 @@ impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
U: Service<Request = (Request, Framed<T, h1::Codec>), Response = ()>,
U::Error: fmt::Display,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, State, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
{
type Output = Result<(), DispatchError>;
@ -597,7 +600,7 @@ where
panic!()
};
let (_, cfg, on_connect, peer_addr) = data.take().unwrap();
self.as_mut().project().state.set(State::H2 {
self.as_mut().project().state.set(ResponseState::H2 {
fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr),
});
self.poll(cx)

View file

@ -1,11 +1,7 @@
//! Test helpers to use during testing.
use std::convert::TryFrom;
use std::str::FromStr;
use std::sync::mpsc;
use std::{io, net, thread, time};
use std::{convert::TryFrom, io, net, str::FromStr, sync::mpsc, thread, time};
use bytes::Bytes;
use futures::Stream;
#[cfg(feature = "cookie")]
use coo_kie::{Cookie, CookieJar};
@ -316,13 +312,11 @@ impl TestServer {
.request(method, self.surl(path.as_ref()).as_str())
}
pub async fn load_body<S>(
/// Load response's body
pub async fn load_body(
&mut self,
mut response: ClientResponse<S>,
) -> Result<Bytes, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
{
mut response: ClientResponse,
) -> Result<Bytes, PayloadError> {
response.body().limit(10_485_760).await
}

View file

@ -189,7 +189,7 @@ mod tests {
crate::rt::spawn(disp.map(|_| ()));
let mut buf = BytesMut::new();
let mut codec = ws::Codec::new().client_mode();
let codec = ws::Codec::new().client_mode();
codec
.encode(ws::Message::Text("test".to_string()), &mut buf)
.unwrap();

View file

@ -5,12 +5,13 @@ use std::rc::Rc;
use futures::future::{FutureExt, LocalBoxFuture};
use crate::http::{Extensions, Request};
use crate::http::Request;
use crate::router::ResourceDef;
use crate::service::boxed::{self, BoxServiceFactory};
use crate::service::{
apply, apply_fn_factory, IntoServiceFactory, ServiceFactory, Transform,
};
use crate::util::Extensions;
use super::app_service::{AppEntry, AppFactory, AppRoutingFactory};
use super::config::{AppConfig, ServiceConfig};

View file

@ -2,9 +2,10 @@ use std::{cell::RefCell, marker::PhantomData, rc::Rc, task::Context, task::Poll}
use futures::future::{ok, FutureExt, LocalBoxFuture};
use crate::http::{Extensions, Request, Response};
use crate::http::{Request, Response};
use crate::router::{Path, ResourceDef, ResourceInfo, Router};
use crate::service::boxed::{self, BoxService, BoxServiceFactory};
use crate::util::Extensions;
use crate::{fn_service, Service, ServiceFactory};
use super::config::AppConfig;

View file

@ -5,10 +5,10 @@ use std::{fmt, net};
use futures::future::{ok, Ready};
use crate::http::{
Extensions, HeaderMap, HttpMessage, Message, Method, Payload, RequestHead, Uri,
Version,
HeaderMap, HttpMessage, Message, Method, Payload, RequestHead, Uri, Version,
};
use crate::router::Path;
use crate::util::Extensions;
use super::config::AppConfig;
use super::error::{ErrorRenderer, UrlGenerationError};

View file

@ -4,10 +4,10 @@ use std::rc::Rc;
use std::{fmt, net};
use crate::http::{
header, Extensions, HeaderMap, HttpMessage, Method, Payload, PayloadStream,
RequestHead, Response, Uri, Version,
header, HeaderMap, HttpMessage, Method, Payload, RequestHead, Response, Uri, Version,
};
use crate::router::{Path, Resource};
use crate::util::Extensions;
use super::config::AppConfig;
use super::error::{ErrorRenderer, WebResponseError};
@ -204,7 +204,7 @@ impl<Err> WebRequest<Err> {
#[inline]
/// Get request's payload
pub fn take_payload(&mut self) -> Payload<PayloadStream> {
pub fn take_payload(&mut self) -> Payload {
Rc::get_mut(&mut (self.req).0).unwrap().payload.take()
}

View file

@ -2,12 +2,13 @@ use std::{cell::RefCell, fmt, rc::Rc, task::Context, task::Poll};
use futures::future::{ok, Either, Future, FutureExt, LocalBoxFuture, Ready};
use crate::http::{Extensions, Response};
use crate::http::Response;
use crate::router::{IntoPattern, ResourceDef};
use crate::service::boxed::{self, BoxService, BoxServiceFactory};
use crate::service::{
apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform,
};
use crate::util::Extensions;
use super::dev::{insert_slesh, WebServiceConfig, WebServiceFactory};
use super::error::ErrorRenderer;

View file

@ -2,12 +2,13 @@ use std::{cell::RefCell, fmt, rc::Rc, task::Context, task::Poll};
use futures::future::{ok, Either, Future, FutureExt, LocalBoxFuture, Ready};
use crate::http::{Extensions, Response};
use crate::http::Response;
use crate::router::{ResourceDef, ResourceInfo, Router};
use crate::service::boxed::{self, BoxService, BoxServiceFactory};
use crate::service::{
apply, apply_fn_factory, IntoServiceFactory, Service, ServiceFactory, Transform,
};
use crate::util::Extensions;
use super::config::ServiceConfig;
use super::dev::{WebServiceConfig, WebServiceFactory};

View file

@ -24,9 +24,9 @@ use super::config::AppConfig;
struct Config {
host: Option<String>,
keep_alive: KeepAlive,
client_timeout: u64,
client_disconnect: u64,
handshake_timeout: u64,
client_timeout: u16,
client_disconnect: u16,
handshake_timeout: u16,
}
/// An HTTP Server.
@ -148,7 +148,7 @@ where
self
}
/// Set server client timeout in milliseconds for first request.
/// Set server client timeout in seconds for first request.
///
/// Defines a timeout for reading client request header. If a client does not transmit
/// the entire set headers within this time, the request is terminated with
@ -157,12 +157,12 @@ where
/// To disable timeout set value to 0.
///
/// By default client timeout is set to 5 seconds.
pub fn client_timeout(self, val: u64) -> Self {
pub fn client_timeout(self, val: u16) -> Self {
self.config.lock().unwrap().client_timeout = val;
self
}
/// Set server connection disconnect timeout in milliseconds.
/// Set server connection disconnect timeout in seconds.
///
/// Defines a timeout for shutdown connection. If a shutdown procedure does not complete
/// within this time, the request is dropped.
@ -170,18 +170,18 @@ where
/// To disable timeout set value to 0.
///
/// By default client timeout is set to 5 seconds.
pub fn disconnect_timeout(self, val: u64) -> Self {
pub fn disconnect_timeout(self, val: u16) -> Self {
self.config.lock().unwrap().client_disconnect = val;
self
}
/// Set server ssl handshake timeout in milliseconds.
/// Set server ssl handshake timeout in seconds.
///
/// Defines a timeout for connection ssl handshake negotiation.
/// To disable timeout set value to 0.
///
/// By default handshake timeout is set to 5 seconds.
pub fn ssl_handshake_timeout(self, val: u64) -> Self {
pub fn ssl_handshake_timeout(self, val: u16) -> Self {
self.config.lock().unwrap().handshake_timeout = val;
self
}

View file

@ -1,8 +1,8 @@
use std::rc::Rc;
use crate::http::Extensions;
use crate::router::{IntoPattern, ResourceDef};
use crate::service::{boxed, IntoServiceFactory, ServiceFactory};
use crate::util::Extensions;
use super::config::AppConfig;
use super::dev::insert_slesh;

View file

@ -1,10 +1,8 @@
//! Various helpers for ntex applications to use during testing.
use std::convert::TryFrom;
use std::error::Error;
use std::net::SocketAddr;
use std::rc::Rc;
use std::sync::mpsc;
use std::{fmt, net, thread, time};
use std::{
convert::TryFrom, error::Error, fmt, net, net::SocketAddr, rc::Rc, sync::mpsc,
thread, time,
};
use bytes::{Bytes, BytesMut};
use futures::future::ok;
@ -22,12 +20,11 @@ use crate::http::client::{Client, ClientRequest, ClientResponse, Connector};
use crate::http::error::{HttpError, PayloadError, ResponseError};
use crate::http::header::{HeaderName, HeaderValue, CONTENT_TYPE};
use crate::http::test::TestRequest as HttpTestRequest;
use crate::http::{
Extensions, HttpService, Method, Payload, Request, StatusCode, Uri, Version,
};
use crate::http::{HttpService, Method, Payload, Request, StatusCode, Uri, Version};
use crate::router::{Path, ResourceDef};
use crate::rt::{time::delay_for, System};
use crate::server::Server;
use crate::util::Extensions;
use crate::{map_config, IntoService, IntoServiceFactory, Service, ServiceFactory};
use crate::web::config::AppConfig;
@ -776,7 +773,7 @@ where
pub struct TestServerConfig {
tp: HttpVer,
stream: StreamType,
client_timeout: u64,
client_timeout: u16,
}
#[derive(Clone, Debug)]
@ -854,8 +851,8 @@ impl TestServerConfig {
self
}
/// Set server client timeout in milliseconds for first request.
pub fn client_timeout(mut self, val: u64) -> Self {
/// Set server client timeout in seconds for first request.
pub fn client_timeout(mut self, val: u16) -> Self {
self.client_timeout = val;
self
}
@ -927,13 +924,11 @@ impl TestServer {
self.client.request(method, path.as_ref())
}
pub async fn load_body<S>(
/// Load response's body
pub async fn load_body(
&self,
mut response: ClientResponse<S>,
) -> Result<Bytes, PayloadError>
where
S: Stream<Item = Result<Bytes, PayloadError>> + Unpin + 'static,
{
mut response: ClientResponse,
) -> Result<Bytes, PayloadError> {
response.body().limit(10_485_760).await
}

View file

@ -1,9 +1,9 @@
use std::ops::Deref;
use std::sync::Arc;
use std::{ops::Deref, sync::Arc};
use futures::future::{err, ok, Ready};
use crate::http::{Extensions, Payload};
use crate::http::Payload;
use crate::util::Extensions;
use crate::web::error::{DataExtractorError, ErrorRenderer};
use crate::web::extract::FromRequest;
use crate::web::httprequest::HttpRequest;

View file

@ -1,4 +1,5 @@
use bytes::{Bytes, BytesMut};
use std::cell::Cell;
use crate::codec::{Decoder, Encoder};
@ -49,10 +50,10 @@ pub enum Item {
Last(Bytes),
}
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Clone)]
/// WebSockets protocol codec
pub struct Codec {
flags: Flags,
flags: Cell<Flags>,
max_size: usize,
}
@ -69,7 +70,7 @@ impl Codec {
pub fn new() -> Codec {
Codec {
max_size: 65_536,
flags: Flags::SERVER,
flags: Cell::new(Flags::SERVER),
}
}
@ -84,10 +85,22 @@ impl Codec {
/// Set decoder to client mode.
///
/// By default decoder works in server mode.
pub fn client_mode(mut self) -> Self {
self.flags.remove(Flags::SERVER);
pub fn client_mode(self) -> Self {
self.remove_flags(Flags::SERVER);
self
}
fn insert_flags(&self, f: Flags) {
let mut flags = self.flags.get();
flags.insert(f);
self.flags.set(flags);
}
fn remove_flags(&self, f: Flags) {
let mut flags = self.flags.get();
flags.remove(f);
self.flags.set(flags);
}
}
impl Default for Codec {
@ -100,90 +113,92 @@ impl Encoder for Codec {
type Item = Message;
type Error = ProtocolError;
fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
fn encode(&self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Text(txt) => Parser::write_message(
dst,
txt,
OpCode::Text,
true,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
),
Message::Binary(bin) => Parser::write_message(
dst,
bin,
OpCode::Binary,
true,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
),
Message::Ping(txt) => Parser::write_message(
dst,
txt,
OpCode::Ping,
true,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
),
Message::Pong(txt) => Parser::write_message(
dst,
txt,
OpCode::Pong,
true,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
),
Message::Close(reason) => Parser::write_close(
dst,
reason,
!self.flags.get().contains(Flags::SERVER),
),
Message::Close(reason) => {
Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER))
}
Message::Continuation(cont) => match cont {
Item::FirstText(data) => {
if self.flags.contains(Flags::W_CONTINUATION) {
if self.flags.get().contains(Flags::W_CONTINUATION) {
return Err(ProtocolError::ContinuationStarted);
} else {
self.flags.insert(Flags::W_CONTINUATION);
self.insert_flags(Flags::W_CONTINUATION);
Parser::write_message(
dst,
&data[..],
OpCode::Text,
false,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
)
}
}
Item::FirstBinary(data) => {
if self.flags.contains(Flags::W_CONTINUATION) {
if self.flags.get().contains(Flags::W_CONTINUATION) {
return Err(ProtocolError::ContinuationStarted);
} else {
self.flags.insert(Flags::W_CONTINUATION);
self.insert_flags(Flags::W_CONTINUATION);
Parser::write_message(
dst,
&data[..],
OpCode::Binary,
false,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
)
}
}
Item::Continue(data) => {
if self.flags.contains(Flags::W_CONTINUATION) {
if self.flags.get().contains(Flags::W_CONTINUATION) {
Parser::write_message(
dst,
&data[..],
OpCode::Continue,
false,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
)
} else {
return Err(ProtocolError::ContinuationNotStarted);
}
}
Item::Last(data) => {
if self.flags.contains(Flags::W_CONTINUATION) {
self.flags.remove(Flags::W_CONTINUATION);
if self.flags.get().contains(Flags::W_CONTINUATION) {
self.remove_flags(Flags::W_CONTINUATION);
Parser::write_message(
dst,
&data[..],
OpCode::Continue,
true,
!self.flags.contains(Flags::SERVER),
!self.flags.get().contains(Flags::SERVER),
)
} else {
return Err(ProtocolError::ContinuationNotStarted);
@ -199,14 +214,15 @@ impl Decoder for Codec {
type Item = Frame;
type Error = ProtocolError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) {
fn decode(&self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match Parser::parse(src, self.flags.get().contains(Flags::SERVER), self.max_size)
{
Ok(Some((finished, opcode, payload))) => {
// handle continuation
if !finished {
return match opcode {
OpCode::Continue => {
if self.flags.contains(Flags::CONTINUATION) {
if self.flags.get().contains(Flags::CONTINUATION) {
Ok(Some(Frame::Continuation(Item::Continue(
payload
.map(|pl| pl.freeze())
@ -217,8 +233,8 @@ impl Decoder for Codec {
}
}
OpCode::Binary => {
if !self.flags.contains(Flags::CONTINUATION) {
self.flags.insert(Flags::CONTINUATION);
if !self.flags.get().contains(Flags::CONTINUATION) {
self.insert_flags(Flags::CONTINUATION);
Ok(Some(Frame::Continuation(Item::FirstBinary(
payload
.map(|pl| pl.freeze())
@ -229,8 +245,8 @@ impl Decoder for Codec {
}
}
OpCode::Text => {
if !self.flags.contains(Flags::CONTINUATION) {
self.flags.insert(Flags::CONTINUATION);
if !self.flags.get().contains(Flags::CONTINUATION) {
self.insert_flags(Flags::CONTINUATION);
Ok(Some(Frame::Continuation(Item::FirstText(
payload
.map(|pl| pl.freeze())
@ -249,8 +265,8 @@ impl Decoder for Codec {
match opcode {
OpCode::Continue => {
if self.flags.contains(Flags::CONTINUATION) {
self.flags.remove(Flags::CONTINUATION);
if self.flags.get().contains(Flags::CONTINUATION) {
self.remove_flags(Flags::CONTINUATION);
Ok(Some(Frame::Continuation(Item::Last(
payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new),
))))

View file

@ -187,7 +187,7 @@ mod tests {
let mut decoder = StreamDecoder::new(rx);
let mut buf = BytesMut::new();
let mut codec = Codec::new().client_mode();
let codec = Codec::new().client_mode();
codec
.encode(Message::Text("test1".to_string()), &mut buf)
.unwrap();

View file

@ -4,41 +4,58 @@ use bytes::Bytes;
use futures::future::ok;
use futures::{SinkExt, StreamExt};
use ntex::codec::Framed;
use ntex::framed::{DispatchItem, Dispatcher, State};
use ntex::http::test::server as test_server;
use ntex::http::ws::handshake_response;
use ntex::http::{body::BodySize, h1, HttpService, Request, Response};
use ntex::util::framed::Dispatcher;
use ntex::ws;
async fn ws_service(req: ws::Frame) -> Result<Option<ws::Message>, io::Error> {
let item = match req {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8(Vec::from(text.as_ref())).unwrap())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => ws::Message::Close(None),
async fn ws_service(
msg: DispatchItem<ws::Codec>,
) -> Result<Option<ws::Message>, io::Error> {
println!("TEST: {:?}", msg);
let msg = match msg {
DispatchItem::Item(msg) => match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8(Vec::from(text.as_ref())).unwrap())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => ws::Message::Close(None),
},
_ => return Ok(None),
};
Ok(Some(item))
Ok(Some(msg))
}
#[ntex::test]
async fn test_simple() {
std::env::set_var("RUST_LOG", "ntex_codec=info,ntex=trace");
env_logger::init();
let mut srv = test_server(|| {
HttpService::build()
.upgrade(|(req, mut framed): (Request, Framed<_, _>)| {
.upgrade(|(req, state, mut codec): (Request, State, h1::Codec)| {
async move {
let res = handshake_response(req.head()).finish();
// send handshake response
framed
.send(h1::Message::Item((res.drop_body(), BodySize::None)))
.await?;
// send handshake respone
state
.write_item(
h1::Message::Item((res.drop_body(), BodySize::None)),
&mut codec,
)
.unwrap();
// start websocket service
let framed = framed.into_framed(ws::Codec::default());
Dispatcher::new(framed, ws_service).await
Dispatcher::from_state(
ws::Codec::default(),
state,
ws_service,
Default::default(),
)
.await
}
})
.finish(|_| ok::<_, io::Error>(Response::NotFound()))

View file

@ -114,7 +114,7 @@ async fn test_h2_body() -> io::Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let mut srv = test_server(move || {
HttpService::build()
.h2(|mut req: Request<_>| async move {
.h2(|mut req: Request| async move {
let body = load_body(req.take_payload())
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
@ -443,7 +443,7 @@ async fn test_ssl_handshake_timeout() {
let srv = test_server(move || {
HttpService::build()
.ssl_handshake_timeout(50)
.ssl_handshake_timeout(1)
.h2(|_| ok::<_, io::Error>(Response::Ok().finish()))
.openssl(ssl_acceptor())
.map_err(|_| ())

View file

@ -105,7 +105,7 @@ async fn test_h2_body1() -> io::Result<()> {
let data = "HELLOWORLD".to_owned().repeat(64 * 1024);
let mut srv = test_server(move || {
HttpService::build()
.h2(|mut req: Request<_>| async move {
.h2(|mut req: Request| async move {
let body = load_body(req.take_payload())
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
@ -446,7 +446,7 @@ async fn test_ssl_handshake_timeout() {
let srv = test_server(move || {
HttpService::build()
.ssl_handshake_timeout(50)
.ssl_handshake_timeout(1)
.h2(|_| ok::<_, io::Error>(Response::Ok().finish()))
.rustls(ssl_acceptor())
});

View file

@ -178,7 +178,7 @@ async fn test_chunked_payload() {
async fn test_slow_request() {
let srv = test_server(|| {
HttpService::build()
.client_timeout(100)
.client_timeout(1)
.finish(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
});

View file

@ -1,47 +1,39 @@
use std::cell::Cell;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::{cell::Cell, io, pin::Pin};
use bytes::Bytes;
use futures::{future, Future, SinkExt, StreamExt};
use ntex::codec::{AsyncRead, AsyncWrite, Framed};
use ntex::http::ws::handshake;
use ntex::http::{body, h1, test, HttpService, Request, Response};
use ntex::framed::{DispatchItem, Dispatcher, State, Timer};
use ntex::http::{body, h1, test, ws::handshake, HttpService, Request, Response};
use ntex::service::{fn_factory, Service};
use ntex::util::framed::Dispatcher;
use ntex::ws;
struct WsService<T>(Arc<Mutex<(PhantomData<T>, Cell<bool>)>>);
struct WsService(Arc<Mutex<Cell<bool>>>);
impl<T> WsService<T> {
impl WsService {
fn new() -> Self {
WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false)))))
WsService(Arc::new(Mutex::new(Cell::new(false))))
}
fn set_polled(&self) {
*self.0.lock().unwrap().1.get_mut() = true;
*self.0.lock().unwrap().get_mut() = true;
}
fn was_polled(&self) -> bool {
self.0.lock().unwrap().1.get()
self.0.lock().unwrap().get()
}
}
impl<T> Clone for WsService<T> {
impl Clone for WsService {
fn clone(&self) -> Self {
WsService(self.0.clone())
}
}
impl<T> Service for WsService<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Request = (Request, Framed<T, h1::Codec>);
impl Service for WsService {
type Request = (Request, State, h1::Codec);
type Response = ();
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<(), io::Error>>>>;
@ -51,16 +43,15 @@ where
Poll::Ready(Ok(()))
}
fn call(&self, (req, mut framed): Self::Request) -> Self::Future {
fn call(&self, (req, state, mut codec): Self::Request) -> Self::Future {
let fut = async move {
let res = handshake(req.head()).unwrap().message_body(());
framed
.send((res, body::BodySize::None).into())
.await
state
.write_item((res, body::BodySize::None).into(), &mut codec)
.unwrap();
Dispatcher::new(framed.into_framed(ws::Codec::new()), service)
Dispatcher::from_state(ws::Codec::new(), state, service, Timer::default())
.await
.map_err(|_| panic!())
};
@ -69,16 +60,21 @@ where
}
}
async fn service(msg: ws::Frame) -> Result<Option<ws::Message>, io::Error> {
async fn service(
msg: DispatchItem<ws::Codec>,
) -> Result<Option<ws::Message>, io::Error> {
let msg = match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text).to_string())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
ws::Frame::Continuation(item) => ws::Message::Continuation(item),
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(),
DispatchItem::Item(msg) => match msg {
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
ws::Frame::Text(text) => {
ws::Message::Text(String::from_utf8_lossy(&text).to_string())
}
ws::Frame::Binary(bin) => ws::Message::Binary(bin),
ws::Frame::Continuation(item) => ws::Message::Continuation(item),
ws::Frame::Close(reason) => ws::Message::Close(reason),
_ => panic!(),
},
_ => return Ok(None),
};
Ok(Some(msg))
}

View file

@ -1054,7 +1054,7 @@ async fn test_server_cookies() {
async fn test_slow_request() {
use std::net;
let srv = test::server_with(test::config().client_timeout(200), || {
let srv = test::server_with(test::config().client_timeout(1), || {
App::new()
.service(web::resource("/").route(web::to(|| async { HttpResponse::Ok() })))
});