use ntex-io instead of framed

This commit is contained in:
Nikolay Kim 2021-12-15 14:09:36 +06:00
parent dafd339817
commit 3dbba47ab1
62 changed files with 1545 additions and 5639 deletions

View file

@ -1,8 +1,14 @@
# Changes
## [0.6.0] - 2021-12-xx
* Removed Framed type
* Removed tokio dependency
## [0.5.1] - 2021-09-08
* Fix tight loop in Framed::close() method.
* Fix tight loop in Framed::close() method
## [0.5.0] - 2021-06-27

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-codec"
version = "0.5.1"
version = "0.6.0"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -16,12 +16,4 @@ name = "ntex_codec"
path = "src/lib.rs"
[dependencies]
bitflags = "1.3"
ntex-bytes = "0.1"
ntex-util = "0.1"
log = "0.4"
tokio = { version = "1", default-features = false }
[dev-dependencies]
ntex = "0.4.13"
futures = "0.3.13"

View file

@ -1,691 +0,0 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io};
use ntex_bytes::{Buf, BytesMut};
use ntex_util::{future::Either, ready, Sink, Stream};
use crate::{AsyncRead, AsyncWrite, Decoder, Encoder};
const LW: usize = 1024;
const HW: usize = 8 * 1024;
bitflags::bitflags! {
struct Flags: u8 {
const EOF = 0b0001;
const READABLE = 0b0010;
const DISCONNECTED = 0b0100;
const SHUTDOWN = 0b1000;
}
}
/// A unified interface to an underlying I/O object, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
/// `Framed` is heavily optimized for streaming io.
pub struct Framed<T, U> {
io: T,
codec: U,
flags: Flags,
read_buf: BytesMut,
write_buf: BytesMut,
err: Option<io::Error>,
}
impl<T, U> Framed<T, U>
where
T: AsyncRead + AsyncWrite,
U: Decoder + Encoder,
{
#[inline]
/// Provides an interface for reading and writing to
/// `Io` object, using `Decode` and `Encode` traits of codec.
///
/// Raw I/O objects work with byte sequences, but higher-level code usually
/// wants to batch these into meaningful chunks, called "frames". This
/// method layers framing on top of an I/O object, by using the `Codec`
/// traits to handle encoding and decoding of messages frames. Note that
/// the incoming and outgoing frame types may be distinct.
pub fn new(io: T, codec: U) -> Framed<T, U> {
Framed {
io,
codec,
err: None,
flags: Flags::empty(),
read_buf: BytesMut::with_capacity(HW),
write_buf: BytesMut::with_capacity(HW),
}
}
}
impl<T, U> Framed<T, U> {
#[inline]
/// Construct `Framed` object `parts`.
pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
Framed {
io: parts.io,
codec: parts.codec,
flags: parts.flags,
write_buf: parts.write_buf,
read_buf: parts.read_buf,
err: parts.err,
}
}
#[inline]
/// Returns a reference to the underlying codec.
pub fn get_codec(&self) -> &U {
&self.codec
}
#[inline]
/// Returns a mutable reference to the underlying codec.
pub fn get_codec_mut(&mut self) -> &mut U {
&mut self.codec
}
#[inline]
/// Returns a reference to the underlying I/O stream wrapped by `Framed`.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn get_ref(&self) -> &T {
&self.io
}
#[inline]
/// Returns a mutable reference to the underlying I/O stream wrapped by
/// `Framed`.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn get_mut(&mut self) -> &mut T {
&mut self.io
}
#[inline]
/// Get read buffer.
pub fn read_buf(&mut self) -> &mut BytesMut {
&mut self.read_buf
}
#[inline]
/// Get write buffer.
pub fn write_buf(&mut self) -> &mut BytesMut {
&mut self.write_buf
}
#[inline]
/// Check if write buffer is empty.
pub fn is_write_buf_empty(&self) -> bool {
self.write_buf.is_empty()
}
#[inline]
/// Check if write buffer is full.
pub fn is_write_buf_full(&self) -> bool {
self.write_buf.len() >= HW
}
#[inline]
/// Check if framed object is closed
pub fn is_closed(&self) -> bool {
self.flags.contains(Flags::DISCONNECTED)
}
#[inline]
/// Consume the `Frame`, returning `Frame` with different codec.
pub fn into_framed<U2>(self, codec: U2) -> Framed<T, U2> {
Framed {
codec,
io: self.io,
flags: self.flags,
read_buf: self.read_buf,
write_buf: self.write_buf,
err: self.err,
}
}
#[inline]
/// Consume the `Frame`, returning `Frame` with different io.
pub fn map_io<F, T2>(self, f: F) -> Framed<T2, U>
where
F: Fn(T) -> T2,
{
Framed {
io: f(self.io),
codec: self.codec,
flags: self.flags,
read_buf: self.read_buf,
write_buf: self.write_buf,
err: self.err,
}
}
#[inline]
/// Consume the `Frame`, returning `Frame` with different codec.
pub fn map_codec<F, U2>(self, f: F) -> Framed<T, U2>
where
F: Fn(U) -> U2,
{
Framed {
io: self.io,
codec: f(self.codec),
flags: self.flags,
read_buf: self.read_buf,
write_buf: self.write_buf,
err: self.err,
}
}
#[inline]
/// Consumes the `Frame`, returning its underlying I/O stream, the buffer
/// with unprocessed data, and the codec.
///
/// Note that care should be taken to not tamper with the underlying stream
/// of data coming in as it may corrupt the stream of frames otherwise
/// being worked with.
pub fn into_parts(self) -> FramedParts<T, U> {
FramedParts {
io: self.io,
codec: self.codec,
flags: self.flags,
read_buf: self.read_buf,
write_buf: self.write_buf,
err: self.err,
}
}
}
impl<T, U> Framed<T, U>
where
T: AsyncWrite + Unpin,
U: Encoder,
{
#[inline]
/// Serialize item and Write to the inner buffer
pub fn write(
&mut self,
item: <U as Encoder>::Item,
) -> Result<(), <U as Encoder>::Error> {
let remaining = self.write_buf.capacity() - self.write_buf.len();
if remaining < LW {
self.write_buf.reserve(HW - remaining);
}
self.codec.encode(item, &mut self.write_buf)?;
Ok(())
}
#[inline]
/// Check if framed is able to write more data.
///
/// `Framed` object considers ready if there is free space in write buffer.
pub fn is_write_ready(&self) -> bool {
self.write_buf.len() < HW
}
/// Flush write buffer to underlying I/O stream.
pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
log::trace!("flushing framed transport");
let len = self.write_buf.len();
if len != 0 {
let mut written = 0;
while written < len {
match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) {
Poll::Pending => break,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!(
"Disconnected during flush, written {}",
written
);
self.flags.insert(Flags::DISCONNECTED);
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)));
} else {
written += n
}
}
Poll::Ready(Err(e)) => {
log::trace!("Error during flush: {}", e);
self.flags.insert(Flags::DISCONNECTED);
return Poll::Ready(Err(e));
}
}
}
log::trace!("flushed {} bytes", written);
// remove written data
if written == len {
self.write_buf.clear()
} else {
self.write_buf.advance(written);
}
}
// flush
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
if self.write_buf.is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
impl<T, U> Framed<T, U>
where
T: AsyncRead + AsyncWrite + Unpin,
{
#[inline]
/// Flush write buffer and shutdown underlying I/O stream.
///
/// Close method shutdown write side of a io object and
/// then reads until disconnect or error, high level code must use
/// timeout for close operation.
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if !self.flags.contains(Flags::DISCONNECTED) {
// flush write buffer
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
if !self.flags.contains(Flags::SHUTDOWN) {
// shutdown WRITE side
ready!(Pin::new(&mut self.io).poll_shutdown(cx)).map_err(|e| {
self.flags.insert(Flags::DISCONNECTED);
e
})?;
self.flags.insert(Flags::SHUTDOWN);
}
// read until 0 or err
let mut buf = [0u8; 512];
loop {
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
match ready!(Pin::new(&mut self.io).poll_read(cx, &mut read_buf)) {
Err(_) | Ok(_) if read_buf.filled().is_empty() => {
break;
}
_ => (),
}
}
self.flags.insert(Flags::DISCONNECTED);
}
log::trace!("framed transport flushed and closed");
Poll::Ready(Ok(()))
}
}
pub type ItemType<U> =
Result<<U as Decoder>::Item, Either<<U as Decoder>::Error, io::Error>>;
impl<T, U> Framed<T, U>
where
T: AsyncRead + Unpin,
U: Decoder,
{
/// Try to read underlying I/O stream and decode item.
pub fn next_item(&mut self, cx: &mut Context<'_>) -> Poll<Option<ItemType<U>>> {
let mut done_read = false;
loop {
// Repeatedly call `decode` or `decode_eof` as long as it is
// "readable". Readable is defined as not having returned `None`. If
// the upstream has returned EOF, and the decoder is no longer
// readable, it can be assumed that the decoder will never become
// readable again, at which point the stream is terminated.
if self.flags.contains(Flags::READABLE) {
if self.flags.contains(Flags::EOF) {
return match self.codec.decode_eof(&mut self.read_buf) {
Ok(Some(frame)) => Poll::Ready(Some(Ok(frame))),
Ok(None) => {
if let Some(err) = self.err.take() {
Poll::Ready(Some(Err(Either::Right(err))))
} else if !self.read_buf.is_empty() {
Poll::Ready(Some(Err(Either::Right(io::Error::new(
io::ErrorKind::Other,
"bytes remaining on stream",
)))))
} else {
Poll::Ready(None)
}
}
Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))),
};
}
log::trace!("attempting to decode a frame");
match self.codec.decode(&mut self.read_buf) {
Ok(Some(frame)) => {
log::trace!("frame decoded from buffer");
return Poll::Ready(Some(Ok(frame)));
}
Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))),
_ => (), // Need more data
}
self.flags.remove(Flags::READABLE);
if done_read {
return Poll::Pending;
}
}
debug_assert!(!self.flags.contains(Flags::EOF));
// read all data from socket
let mut updated = false;
loop {
// Otherwise, try to read more data and try again. Make sure we've got room
let remaining = self.read_buf.capacity() - self.read_buf.len();
if remaining < LW {
self.read_buf.reserve(HW - remaining)
}
match crate::poll_read_buf(
Pin::new(&mut self.io),
cx,
&mut self.read_buf,
) {
Poll::Pending => {
if updated {
done_read = true;
self.flags.insert(Flags::READABLE);
break;
} else {
return Poll::Pending;
}
}
Poll::Ready(Ok(n)) => {
if n == 0 {
self.flags.insert(Flags::EOF | Flags::READABLE);
if updated {
done_read = true;
}
break;
} else {
updated = true;
}
}
Poll::Ready(Err(e)) => {
if updated {
done_read = true;
self.err = Some(e);
self.flags.insert(Flags::EOF | Flags::READABLE);
break;
} else {
return Poll::Ready(Some(Err(Either::Right(e))));
}
}
}
}
}
}
}
impl<T, U> Stream for Framed<T, U>
where
T: AsyncRead + Unpin,
U: Decoder + Unpin,
{
type Item = Result<U::Item, Either<U::Error, io::Error>>;
#[inline]
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
self.next_item(cx)
}
}
impl<T, U> Sink<U::Item> for Framed<T, U>
where
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Unpin,
{
type Error = Either<U::Error, io::Error>;
#[inline]
fn poll_ready(
self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
if self.is_write_ready() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
#[inline]
fn start_send(
mut self: Pin<&mut Self>,
item: <U as Encoder>::Item,
) -> Result<(), Self::Error> {
self.write(item).map_err(Either::Left)
}
#[inline]
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.flush(cx).map_err(Either::Right)
}
#[inline]
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.close(cx).map_err(Either::Right)
}
}
impl<T, U> fmt::Debug for Framed<T, U>
where
T: fmt::Debug,
U: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Framed")
.field("io", &self.io)
.field("codec", &self.codec)
.finish()
}
}
/// `FramedParts` contains an export of the data of a Framed transport.
/// It can be used to construct a new `Framed` with a different codec.
/// It contains all current buffers and the inner transport.
#[derive(Debug)]
pub struct FramedParts<T, U> {
/// The inner transport used to read bytes to and write bytes to
pub io: T,
/// The codec
pub codec: U,
/// The buffer with read but unprocessed data.
pub read_buf: BytesMut,
/// A buffer with unprocessed data which are not written yet.
pub write_buf: BytesMut,
flags: Flags,
err: Option<io::Error>,
}
impl<T, U> FramedParts<T, U> {
/// Create a new, default, `FramedParts`
pub fn new(io: T, codec: U) -> FramedParts<T, U> {
FramedParts {
io,
codec,
err: None,
flags: Flags::empty(),
read_buf: BytesMut::new(),
write_buf: BytesMut::new(),
}
}
/// Create a new `FramedParts` with read buffer
pub fn with_read_buf(io: T, codec: U, read_buf: BytesMut) -> FramedParts<T, U> {
FramedParts {
io,
codec,
read_buf,
err: None,
flags: Flags::empty(),
write_buf: BytesMut::new(),
}
}
}
#[cfg(test)]
mod tests {
use futures::{future::lazy, Sink};
use ntex::testing::Io;
use ntex_bytes::Bytes;
use super::*;
use crate::BytesCodec;
#[ntex::test]
async fn test_basics() {
let (_, server) = Io::create();
let mut server = Framed::new(server, BytesCodec);
server.get_codec_mut();
server.get_ref();
server.get_mut();
let parts = server.into_parts();
let server = Framed::from_parts(FramedParts::new(parts.io, parts.codec));
assert!(format!("{:?}", server).contains("Framed"));
}
#[ntex::test]
async fn test_sink() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
let mut server = Framed::new(server, BytesCodec);
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
.await
.is_ready());
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
Pin::new(&mut server).start_send(data).unwrap();
assert_eq!(client.read_any(), b"".as_ref());
assert_eq!(server.read_buf(), b"".as_ref());
assert_eq!(server.write_buf(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_ready());
assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_pending());
client.close().await;
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_ready());
assert!(client.is_closed());
}
#[ntex::test]
async fn test_write_pending() {
let (client, server) = Io::create();
let mut server = Framed::new(server, BytesCodec);
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
.await
.is_ready());
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
Pin::new(&mut server).start_send(data).unwrap();
client.remote_buffer_cap(3);
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_pending());
assert_eq!(client.read_any(), b"GET".as_ref());
client.remote_buffer_cap(1024);
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_ready());
assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref());
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_pending());
client.close().await;
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_ready());
assert!(client.is_closed());
assert!(server.is_closed());
}
#[ntex::test]
async fn test_read_pending() {
let (client, server) = Io::create();
let mut server = Framed::new(server, BytesCodec);
client.read_pending();
assert!(lazy(|cx| Pin::new(&mut server).next_item(cx))
.await
.is_pending());
client.write(b"GET /test HTTP/1.1\r\n\r\n");
client.close().await;
let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
.await
.map(|i| i.unwrap().unwrap().freeze());
assert_eq!(
item,
Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"))
);
let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
.await
.map(|i| i.is_none());
assert_eq!(item, Poll::Ready(true));
}
#[ntex::test]
async fn test_read_error() {
let (client, server) = Io::create();
let mut server = Framed::new(server, BytesCodec);
client.read_pending();
assert!(lazy(|cx| Pin::new(&mut server).next_item(cx))
.await
.is_pending());
client.write(b"GET /test HTTP/1.1\r\n\r\n");
client.read_error(io::Error::new(io::ErrorKind::Other, "error"));
let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
.await
.map(|i| i.unwrap().unwrap().freeze());
assert_eq!(
item,
Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"))
);
assert_eq!(
lazy(|cx| Pin::new(&mut server).next_item(cx))
.await
.map(|i| i.unwrap().is_err()),
Poll::Ready(true)
);
}
}

View file

@ -1,55 +1,10 @@
//! Utilities for encoding and decoding frames.
//!
//! Contains adapters to go from streams of bytes, [`AsyncRead`] and
//! [`AsyncWrite`], to framed streams implementing `Sink` and `Stream`.
//! Framed streams are also known as `transports`.
//!
//! [`AsyncRead`]: #
//! [`AsyncWrite`]: #
#![deny(rust_2018_idioms, warnings)]
use std::{io, mem::MaybeUninit, pin::Pin, task::Context, task::Poll};
//! Utilities for encoding and decoding frames.
mod bcodec;
mod decoder;
mod encoder;
mod framed;
pub use self::bcodec::BytesCodec;
pub use self::decoder::Decoder;
pub use self::encoder::Encoder;
pub use self::framed::{Framed, FramedParts};
pub use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use ntex_bytes::{BufMut, BytesMut};
pub fn poll_read_buf<T: AsyncRead>(
io: Pin<&mut T>,
cx: &mut Context<'_>,
buf: &mut BytesMut,
) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let n = {
let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
if io.poll_read(cx, &mut buf)?.is_pending() {
return Poll::Pending;
}
// Ensure the pointer does not change from under us
assert_eq!(ptr, buf.filled().as_ptr());
buf.filled().len()
};
// Safety: This is guaranteed to be the number of initialized (and read)
// bytes due to the invariants provided by `ReadBuf::filled`.
unsafe {
buf.advance_mut(n);
}
Poll::Ready(Ok(n))
}

View file

@ -1,6 +1,6 @@
//! Framed transport dispatcher
use std::{
cell::Cell, future::Future, io, pin::Pin, rc::Rc, task::Context, task::Poll, time,
cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time,
};
use ntex_bytes::Pool;
@ -70,7 +70,6 @@ enum DispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
Io(io::Error),
}
enum PollService<U: Encoder + Decoder> {
@ -157,7 +156,7 @@ where
///
/// By default disconnect timeout is set to 1 seconds.
pub fn disconnect_timeout(self, val: Seconds) -> Self {
self.inner.state.set_disconnect_timeout(val);
self.inner.state.set_disconnect_timeout(val.into());
self
}
}
@ -176,12 +175,7 @@ where
Ok(Some(val)) => match write.encode(val, &self.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.error.set(Some(DispatcherError::Encoder(err)))
}
Err(Either::Right(err)) => {
self.error.set(Some(DispatcherError::Io(err)))
}
Err(err) => self.error.set(Some(DispatcherError::Encoder(err))),
},
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
Ok(None) => return,
@ -407,12 +401,7 @@ where
Ok(Some(item)) => match write.encode(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
Err(Either::Right(err)) => {
self.shared.error.set(Some(DispatcherError::Io(err)))
}
Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))),
},
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
Ok(None) => (),
@ -443,9 +432,6 @@ where
DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
DispatcherError::Io(err) => {
PollService::Item(DispatchItem::Disconnect(Some(err)))
}
DispatcherError::Service(err) => {
self.error.set(Some(err));
PollService::ServiceError

View file

@ -46,13 +46,7 @@ impl ReadFilter for DefaultFilter {
#[inline]
fn read_closed(&self, err: Option<io::Error>) {
if err.is_some() {
self.0.error.set(err);
}
self.0.write_task.wake();
self.0.dispatch_task.wake();
self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
self.0.notify_disconnect();
self.0.set_error(err);
}
#[inline]
@ -109,13 +103,9 @@ impl WriteFilter for DefaultFilter {
#[inline]
fn write_closed(&self, err: Option<io::Error>) {
if err.is_some() {
self.0.error.set(err);
}
self.0.read_task.wake();
self.0.set_error(err);
self.0.insert_flags(Flags::IO_CLOSED);
self.0.dispatch_task.wake();
self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
self.0.notify_disconnect();
}
#[inline]

View file

@ -1,4 +1,4 @@
use std::{fmt, future::Future, io, task::Context, task::Poll};
use std::{any::Any, any::TypeId, fmt, future::Future, io, task::Context, task::Poll};
pub mod testing;
@ -14,12 +14,12 @@ mod tokio_impl;
use ntex_bytes::BytesMut;
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Millis;
use ntex_util::{channel::oneshot::Receiver, future::Either, time::Millis};
pub use self::dispatcher::Dispatcher;
pub use self::filter::DefaultFilter;
pub use self::state::{Io, IoRef, ReadRef, WriteRef};
pub use self::tasks::{ReadState, WriteState};
pub use self::state::{Io, IoRef, OnDisconnect, ReadRef, WriteRef};
pub use self::tasks::{ReadContext, WriteContext};
pub use self::time::Timer;
pub use self::utils::{filter_factory, from_iostream, into_boxed, into_io};
@ -55,8 +55,15 @@ pub trait WriteFilter {
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
}
pub trait Filter: ReadFilter + WriteFilter {
pub trait Filter: ReadFilter + WriteFilter + 'static {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
fn query(
&self,
id: TypeId,
) -> Either<Option<Box<dyn Any>>, Receiver<Option<Box<dyn Any>>>> {
Either::Left(None)
}
}
pub trait FilterFactory<F: Filter>: Sized {
@ -69,7 +76,7 @@ pub trait FilterFactory<F: Filter>: Sized {
}
pub trait IoStream {
fn start(self, _: ReadState, _: WriteState);
fn start(self, _: ReadContext, _: WriteContext);
}
/// Framed transport item

View file

@ -1,14 +1,14 @@
use std::cell::{Cell, RefCell};
use std::task::{Context, Poll};
use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
use std::{fmt, future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
use ntex_bytes::{BytesMut, PoolId, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::{Millis, Seconds};
use ntex_util::time::Millis;
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
use super::filter::{DefaultFilter, NullFilter};
use super::tasks::{ReadState, WriteState};
use super::tasks::{ReadContext, WriteContext};
use super::{Filter, FilterFactory, IoStream};
bitflags::bitflags! {
@ -21,13 +21,15 @@ bitflags::bitflags! {
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
/// shutdown io tasks
const IO_SHUTDOWN = 0b0000_0000_0000_1000;
/// io object is closed
const IO_CLOSED = 0b0000_0000_0001_0000;
/// pause io read
const RD_PAUSED = 0b0000_0000_0000_1000;
const RD_PAUSED = 0b0000_0000_0010_0000;
/// new data is available
const RD_READY = 0b0000_0000_0001_0000;
const RD_READY = 0b0000_0000_0100_0000;
/// read buffer is full
const RD_BUF_FULL = 0b0000_0000_0010_0000;
const RD_BUF_FULL = 0b0000_0000_1000_0000;
/// wait write completion
const WR_WAIT = 0b0000_0001_0000_0000;
@ -103,8 +105,22 @@ impl IoStateInner {
}
#[inline]
fn is_io_err(&self) -> bool {
self.flags.get().contains(Flags::IO_ERR)
fn is_io_open(&self) -> bool {
!self.flags.get().intersects(
Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_SHUTDOWN | Flags::IO_CLOSED,
)
}
#[inline]
pub(super) fn set_error(&self, err: Option<io::Error>) {
if err.is_some() {
self.error.set(err);
}
self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
self.notify_disconnect();
}
#[inline]
@ -195,7 +211,7 @@ impl Io {
let io_ref = IoRef(inner);
// start io tasks
io.start(ReadState(io_ref.clone()), WriteState(io_ref.clone()));
io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone()));
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
}
@ -218,8 +234,8 @@ impl<F> Io<F> {
#[inline]
/// Set io disconnect timeout in secs
pub fn set_disconnect_timeout(&self, timeout: Seconds) {
self.0 .0.disconnect_timeout.set(timeout.into());
pub fn set_disconnect_timeout(&self, timeout: Millis) {
self.0 .0.disconnect_timeout.set(timeout);
}
}
@ -242,31 +258,6 @@ impl<F> Io<F> {
pub fn register_dispatcher(&self, cx: &mut Context<'_>) {
self.0 .0.dispatch_task.register(cx.waker());
}
#[inline]
/// Mark dispatcher as stopped
pub fn dispatcher_stopped(&self) {
self.0 .0.insert_flags(Flags::DSP_STOP);
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub fn init_shutdown(&self, cx: &mut Context<'_>) {
let flags = self.0 .0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
self.0 .0.insert_flags(Flags::IO_FILTERS);
if let Err(err) = self.0 .0.shutdown_filters(&self.0) {
self.0 .0.error.set(Some(err));
self.0 .0.insert_flags(Flags::IO_ERR);
}
self.0 .0.read_task.wake();
self.0 .0.write_task.wake();
self.0 .0.dispatch_task.register(cx.waker());
}
}
}
impl IoRef {
@ -284,9 +275,9 @@ impl IoRef {
}
#[inline]
/// Check if io error occured in read or write task
pub fn is_io_err(&self) -> bool {
self.0.is_io_err()
/// Check if io is still active
pub fn is_io_open(&self) -> bool {
self.0.is_io_open()
}
#[inline]
@ -304,10 +295,13 @@ impl IoRef {
#[inline]
/// Check if io stream is closed
pub fn is_closed(&self) -> bool {
self.0
.flags
.get()
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP)
self.0.flags.get().intersects(
Flags::IO_ERR
| Flags::IO_SHUTDOWN
| Flags::IO_CLOSED
| Flags::IO_FILTERS
| Flags::DSP_STOP,
)
}
#[inline]
@ -316,6 +310,12 @@ impl IoRef {
self.0.error.take()
}
#[inline]
/// Mark dispatcher as stopped
pub fn stop_dispatcher(&self) {
self.0.insert_flags(Flags::DSP_STOP);
}
#[inline]
/// Reset keep-alive error
pub fn reset_keepalive(&self) {
@ -360,9 +360,15 @@ impl IoRef {
pub fn on_disconnect(&self) -> OnDisconnect {
OnDisconnect::new(self.0.clone(), self.0.flags.get().contains(Flags::IO_ERR))
}
#[inline]
/// Query specific data
pub fn query<T: 'static>(&self) -> Option<T> {
todo!()
}
}
impl<F> Io<F> {
impl IoRef {
#[inline]
/// Read incoming io stream and decode codec item.
pub async fn next<U>(
@ -375,18 +381,18 @@ impl<F> Io<F> {
let read = self.read();
loop {
let mut buf = self.0 .0.read_buf.take();
let mut buf = self.0.read_buf.take();
let item = if let Some(ref mut buf) = buf {
codec.decode(buf)
} else {
Ok(None)
};
self.0 .0.read_buf.set(buf);
self.0.read_buf.set(buf);
return match item {
Ok(Some(el)) => Ok(Some(el)),
Ok(None) => {
self.0 .0.remove_flags(Flags::RD_READY);
self.0.remove_flags(Flags::RD_READY);
if poll_fn(|cx| read.poll_ready(cx))
.await
.map_err(Either::Right)?
@ -411,53 +417,53 @@ impl<F> Io<F> {
where
U: Encoder,
{
let filter = self.0 .0.filter.get();
let filter = self.0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
let is_write_sleep = buf.is_empty();
codec.encode(item, &mut buf).map_err(Either::Left)?;
filter.release_write_buf(buf).map_err(Either::Right)?;
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0.insert_flags(Flags::WR_WAIT);
if is_write_sleep {
self.0 .0.write_task.wake();
self.0.write_task.wake();
}
poll_fn(|cx| self.write().poll_flush(cx))
poll_fn(|cx| self.write().poll_flush(cx, true))
.await
.map_err(Either::Right)?;
Ok(())
}
#[inline]
/// Shuts down connection
pub async fn shutdown(&self) -> Result<(), io::Error> {
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
Ok(())
} else {
poll_fn(|cx| {
let flags = self.flags();
if !flags.contains(Flags::IO_FILTERS) {
self.init_shutdown(cx);
}
/// Shut down connection
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let flags = self.flags();
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
if let Some(err) = self.0 .0.error.take() {
Poll::Ready(Err(err))
} else {
Poll::Ready(Ok(()))
}
} else {
self.0 .0.insert_flags(Flags::IO_FILTERS);
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
}
})
.await
if flags.intersects(Flags::IO_ERR | Flags::IO_CLOSED) {
Poll::Ready(Ok(()))
} else {
if !flags.contains(Flags::IO_FILTERS) {
self.init_shutdown(cx);
}
self.0.insert_flags(Flags::IO_FILTERS);
if let Some(err) = self.0.error.take() {
Poll::Ready(Err(err))
} else {
self.0.dispatch_task.register(cx.waker());
Poll::Pending
}
}
}
#[inline]
/// Shut down connection
pub async fn shutdown(&self) -> Result<(), io::Error> {
poll_fn(|cx| self.poll_shutdown(cx)).await
}
#[inline]
#[allow(clippy::type_complexity)]
pub fn poll_next<U>(
@ -468,38 +474,48 @@ impl<F> Io<F> {
where
U: Decoder,
{
if self
.read()
.poll_ready(cx)
.map_err(Either::Right)?
.is_ready()
{
let mut buf = self.0 .0.read_buf.take();
let item = if let Some(ref mut buf) = buf {
codec.decode(buf)
} else {
Ok(None)
};
self.0 .0.read_buf.set(buf);
let read = self.read();
match item {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => {
if let Poll::Ready(res) =
self.read().poll_ready(cx).map_err(Either::Right)?
{
if res.is_none() {
return Poll::Ready(Ok(None));
}
match read.decode(codec) {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => {
if let Poll::Ready(res) = read.poll_ready(cx).map_err(Either::Right)? {
if res.is_none() {
return Poll::Ready(Ok(None));
}
Poll::Pending
}
Err(err) => Poll::Ready(Err(Either::Left(err))),
Poll::Pending
}
} else {
Poll::Pending
Err(err) => Poll::Ready(Err(Either::Left(err))),
}
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, cx: &mut Context<'_>) {
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
self.0.insert_flags(Flags::IO_FILTERS);
if let Err(err) = self.0.shutdown_filters(self) {
self.0.error.set(Some(err));
self.0.insert_flags(Flags::IO_ERR);
}
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.register(cx.waker());
}
}
}
impl fmt::Debug for IoRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IoRef")
.field("open", &!self.is_closed())
.finish()
}
}
impl<F: Filter> Io<F> {
@ -576,7 +592,10 @@ impl<F: Filter> Io<F> {
impl<F> Drop for Io<F> {
fn drop(&mut self) {
log::trace!("stopping io stream");
log::trace!(
"io is dropped, force stopping io streams {:?}",
self.0.flags()
);
if let FilterItem::Ptr(p) = self.1 {
if p.is_null() {
return;
@ -635,7 +654,7 @@ impl<'a> WriteRef<'a> {
///
/// Write task must be waken up separately.
pub fn enable_backpressure(&self, cx: Option<&mut Context<'_>>) {
log::trace!("enable write back-pressure");
log::trace!("enable write back-pressure {:?}", cx.is_some());
self.0.insert_flags(Flags::WR_BACKPRESSURE);
if let Some(cx) = cx {
self.0.dispatch_task.register(cx.waker());
@ -669,7 +688,7 @@ impl<'a> WriteRef<'a> {
&self,
item: U::Item,
codec: &U,
) -> Result<bool, Either<<U as Encoder>::Error, io::Error>>
) -> Result<bool, <U as Encoder>::Error>
where
U: Encoder,
{
@ -690,28 +709,44 @@ impl<'a> WriteRef<'a> {
}
// encode item and wake write task
let result = codec
.encode(item, &mut buf)
.map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
buf.len() < hw
})
.map_err(Either::Left);
filter.release_write_buf(buf).map_err(Either::Right)?;
Ok(result?)
let result = codec.encode(item, &mut buf).map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
buf.len() < hw
});
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
}
result
} else {
Ok(true)
}
}
#[inline]
/// Wake write task and instruct to write all data.
/// Wake write task and instruct to write data.
///
/// When write task is done wake dispatcher.
pub fn poll_flush(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.0.insert_flags(Flags::WR_WAIT);
/// If full is true then wake up dispatcher when all data is flushed
/// otherwise wake up when size of write buffer is lower than
/// buffer max size.
pub fn poll_flush(
&self,
cx: &mut Context<'_>,
full: bool,
) -> Poll<Result<(), io::Error>> {
// check io error
if !self.0.is_io_open() {
return Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "disconnected")
})));
}
if full {
self.0.insert_flags(Flags::WR_WAIT);
} else {
self.0.insert_flags(Flags::WR_BACKPRESSURE);
}
if let Some(buf) = self.0.write_buf.take() {
if !buf.is_empty() {
@ -722,14 +757,16 @@ impl<'a> WriteRef<'a> {
}
}
if self.0.is_io_err() {
Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "disconnected")
})))
} else {
self.0.dispatch_task.register(cx.waker());
Poll::Ready(Ok(()))
}
// self.0.dispatch_task.register(cx.waker());
Poll::Ready(Ok(()))
}
#[inline]
/// Wake write task and instruct to write data.
///
/// This is async version of .poll_flush() method.
pub async fn flush(&self, full: bool) -> Result<(), io::Error> {
poll_fn(|cx| self.poll_flush(cx, full)).await
}
}
@ -834,7 +871,7 @@ impl<'a> ReadRef<'a> {
let mut flags = self.0.flags.get();
let ready = flags.contains(Flags::RD_READY);
if self.0.is_io_err() {
if !self.0.is_io_open() {
if let Some(err) = self.0.error.take() {
Poll::Ready(Err(err))
} else {
@ -843,7 +880,6 @@ impl<'a> ReadRef<'a> {
} else if ready {
Poll::Ready(Ok(Some(())))
} else {
flags.remove(Flags::RD_READY);
if flags.contains(Flags::RD_BUF_FULL) {
log::trace!("read back-pressure is enabled, wake io task");
flags.remove(Flags::RD_BUF_FULL);
@ -939,7 +975,6 @@ mod tests {
#[ntex::test]
async fn utils() {
env_logger::init();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write(TEXT);
@ -1041,7 +1076,7 @@ mod tests {
in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>,
}
impl<F: ReadFilter + WriteFilter> Filter for Counter<F> {
impl<F: ReadFilter + WriteFilter + 'static> Filter for Counter<F> {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}

View file

@ -4,9 +4,9 @@ use ntex_bytes::{BytesMut, PoolRef};
use super::{state::Flags, IoRef, WriteReadiness};
pub struct ReadState(pub(super) IoRef);
pub struct ReadContext(pub(super) IoRef);
impl ReadState {
impl ReadContext {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0 .0.pool.get()
@ -60,9 +60,9 @@ impl ReadState {
}
}
pub struct WriteState(pub(super) IoRef);
pub struct WriteContext(pub(super) IoRef);
impl WriteState {
impl WriteContext {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0 .0.pool.get()

View file

@ -1,11 +1,13 @@
use std::cell::{Cell, RefCell};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::{cmp, fmt, io, mem};
use std::{cmp, fmt, future::Future, io, mem, pin::Pin, rc::Rc};
use ntex_bytes::{BufMut, BytesMut};
use ntex_bytes::{Buf, BufMut, BytesMut};
use ntex_util::future::poll_fn;
use ntex_util::time::{sleep, Millis};
use ntex_util::time::{sleep, Millis, Sleep};
use crate::{IoStream, ReadContext, WriteContext, WriteReadiness};
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);
@ -441,138 +443,181 @@ mod tokio {
}
}
#[cfg(not(feature = "tokio"))]
mod non_tokio {
impl IoStream for IoTest {
fn start(self, read: ReadState, write: WriteState) {
let io = Rc::new(self);
impl IoStream for IoTest {
fn start(self, read: ReadContext, write: WriteContext) {
let io = Rc::new(self);
ntex_util::spawn(ReadTask {
io: io.clone(),
state: read,
});
ntex_util::spawn(WriteTask {
io,
state: write,
st: IoWriteState::Processing,
});
}
ntex_util::spawn(ReadTask {
io: io.clone(),
state: read,
});
ntex_util::spawn(WriteTask {
io,
state: write,
st: IoWriteState::Processing(None),
});
}
}
/// Read io task
struct ReadTask {
io: Rc<IoTest>,
state: ReadState,
}
/// Read io task
struct ReadTask {
io: Rc<IoTest>,
state: ReadContext,
}
impl Future for ReadTask {
type Output = ();
impl Future for ReadTask {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(Err(())) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(())
}
Poll::Ready(Ok(())) => {
let io = &this.io;
let pool = this.state.memory_pool();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
match this.state.poll_ready(cx) {
Poll::Ready(Err(())) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(())
}
Poll::Ready(Ok(())) => {
let io = &this.io;
let pool = this.state.memory_pool();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match io.poll_read_buf(cx, &mut buf) {
Poll::Pending => {
log::trace!("no more data in io stream");
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
this.state.release_read_buf(buf, new_bytes);
this.state.close(None);
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
}
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
this.state.release_read_buf(buf, new_bytes);
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
}
}
#[derive(Debug)]
enum IoWriteState {
Processing,
Shutdown(Option<Sleep>, Shutdown),
}
#[derive(Debug)]
enum Shutdown {
None,
Flushed,
Stopping,
}
/// Write io task
struct WriteTask {
st: IoWriteState,
io: Rc<IoTest>,
state: WriteState,
}
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
// flush framed instance
match flush_io(&this.io, &this.state, cx) {
Poll::Pending | Poll::Ready(true) => Poll::Pending,
Poll::Ready(false) => Poll::Ready(()),
match io.poll_read_buf(cx, &mut buf) {
Poll::Pending => {
log::trace!("no more data in io stream");
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(None);
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(WriteReadiness::Shutdown)) => {
log::trace!("write task is instructed to shutdown");
this.st = IoWriteState::Shutdown(
this.state.disconnect_timeout().map(sleep),
Shutdown::None,
);
self.poll(cx)
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
log::trace!("write task is instructed to terminate");
}
}
let _ = this.state.release_read_buf(buf, new_bytes);
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
}
}
#[derive(Debug)]
enum IoWriteState {
Processing(Option<Sleep>),
Shutdown(Option<Sleep>, Shutdown),
}
#[derive(Debug)]
enum Shutdown {
None,
Flushed,
Stopping,
}
/// Write io task
struct WriteTask {
st: IoWriteState,
io: Rc<IoTest>,
state: WriteContext,
}
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
// flush framed instance
match flush_io(&this.io, &this.state, cx) {
Poll::Pending | Poll::Ready(true) => Poll::Pending,
Poll::Ready(false) => Poll::Ready(()),
}
}
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(Some(timeout), Shutdown::None);
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
log::trace!("write task is instructed to terminate");
// shutdown WRITE side
this.io
.local
.lock()
.unwrap()
.borrow_mut()
.flags
.insert(Flags::CLOSED);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::None => {
// flush write buffer
match flush_io(&this.io, &this.state, cx) {
Poll::Ready(true) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(false) => {
log::trace!(
"write task is closed with err during flush"
);
return Poll::Ready(());
}
_ => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
this.io
.local
@ -581,143 +626,102 @@ mod non_tokio {
.borrow_mut()
.flags
.insert(Flags::CLOSED);
this.state.close(None);
Poll::Ready(())
*st = Shutdown::Stopping;
continue;
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::None => {
// flush write buffer
match flush_io(&this.io, &this.state, cx) {
Poll::Ready(true) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(false) => {
log::trace!(
"write task is closed with err during flush"
);
Shutdown::Stopping => {
// read until 0 or err
let io = &this.io;
loop {
let mut buf = BytesMut::new();
match io.poll_read_buf(cx, &mut buf) {
Poll::Ready(Err(e)) => {
this.state.close(Some(e));
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Ready(Ok(n)) if n == 0 => {
this.state.close(None);
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Pending => break,
_ => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
this.io
.local
.lock()
.unwrap()
.borrow_mut()
.flags
.insert(Flags::CLOSED);
*st = Shutdown::Stopping;
continue;
}
Shutdown::Stopping => {
// read until 0 or err
let io = &this.io;
loop {
let mut buf = BytesMut::new();
match io.poll_read_buf(cx, &mut buf) {
Poll::Ready(Err(e)) => {
this.state.close(Some(e));
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Ready(Ok(n)) if n == 0 => {
this.state.close(None);
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Pending => break,
_ => (),
}
}
}
}
// disconnect timeout
if let Some(ref delay) = delay {
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
}
log::trace!("write task is stopped after delay");
this.state.close(None);
return Poll::Ready(());
}
// disconnect timeout
if let Some(ref delay) = delay {
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
}
log::trace!("write task is stopped after delay");
this.state.close(None);
return Poll::Ready(());
}
}
}
}
}
/// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io(
io: &IoTest,
state: &WriteState,
cx: &mut Context<'_>,
) -> Poll<bool> {
let mut buf = if let Some(buf) = state.get_write_buf() {
buf
} else {
return Poll::Ready(true);
};
let len = buf.len();
let pool = state.memory_pool();
/// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io(
io: &IoTest,
state: &WriteContext,
cx: &mut Context<'_>,
) -> Poll<bool> {
let mut buf = if let Some(buf) = state.get_write_buf() {
buf
} else {
return Poll::Ready(true);
};
let len = buf.len();
if len != 0 {
log::trace!("flushing framed transport: {}", len);
if len != 0 {
log::trace!("flushing framed transport: {}", len);
let mut written = 0;
while written < len {
match io.poll_write_buf(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!(
"disconnected during flush, written {}",
written
);
pool.release_write_buf(buf);
state.close(Some(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)));
return Poll::Ready(false);
} else {
written += n
}
}
Poll::Pending => break,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
pool.release_write_buf(buf);
state.close(Some(e));
let mut written = 0;
while written < len {
match io.poll_write_buf(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("disconnected during flush, written {}", written);
let _ = state.release_write_buf(buf);
state.close(Some(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)));
return Poll::Ready(false);
} else {
written += n
}
}
Poll::Pending => break,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
let _ = state.release_write_buf(buf);
state.close(Some(e));
return Poll::Ready(false);
}
}
log::trace!("flushed {} bytes", written);
// remove written data
if written == len {
buf.clear();
state.release_write_buf(buf);
Poll::Ready(true)
} else {
buf.advance(written);
state.release_write_buf(buf);
Poll::Pending
}
} else {
Poll::Ready(true)
}
log::trace!("flushed {} bytes", written);
// remove written data
if written == len {
buf.clear();
let _ = state.release_write_buf(buf);
Poll::Ready(true)
} else {
buf.advance(written);
let _ = state.release_write_buf(buf);
Poll::Pending
}
} else {
Poll::Ready(true)
}
}

View file

@ -3,15 +3,12 @@ use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc};
use ntex_bytes::{Buf, BufMut};
use ntex_util::time::{sleep, Sleep};
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf};
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf, net::TcpStream};
use super::{IoStream, ReadState, WriteReadiness, WriteState};
use super::{IoStream, ReadContext, WriteContext, WriteReadiness};
impl<T> IoStream for T
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn start(self, read: ReadState, write: WriteState) {
impl IoStream for TcpStream {
fn start(self, read: ReadContext, write: WriteContext) {
let io = Rc::new(RefCell::new(self));
ntex_util::spawn(ReadTask::new(io.clone(), read));
@ -19,26 +16,29 @@ where
}
}
/// Read io task
struct ReadTask<T> {
io: Rc<RefCell<T>>,
state: ReadState,
#[cfg(unix)]
impl IoStream for tok_io::net::UnixStream {
fn start(self, _read: ReadContext, _write: WriteContext) {
let _io = Rc::new(RefCell::new(self));
todo!()
}
}
impl<T> ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Read io task
struct ReadTask {
io: Rc<RefCell<TcpStream>>,
state: ReadContext,
}
impl ReadTask {
/// Create new read io task
fn new(io: Rc<RefCell<T>>, state: ReadState) -> Self {
fn new(io: Rc<RefCell<TcpStream>>, state: ReadContext) -> Self {
Self { io, state }
}
}
impl<T> Future for ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl Future for ReadTask {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -119,18 +119,15 @@ enum Shutdown {
}
/// Write io task
struct WriteTask<T> {
struct WriteTask {
st: IoWriteState,
io: Rc<RefCell<T>>,
state: WriteState,
io: Rc<RefCell<TcpStream>>,
state: WriteContext,
}
impl<T> WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl WriteTask {
/// Create new write io task
fn new(io: Rc<RefCell<T>>, state: WriteState) -> Self {
fn new(io: Rc<RefCell<TcpStream>>, state: WriteContext) -> Self {
Self {
io,
state,
@ -139,10 +136,7 @@ where
}
}
impl<T> Future for WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -272,7 +266,7 @@ where
/// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
io: &mut T,
state: &WriteState,
state: &WriteContext,
cx: &mut Context<'_>,
) -> Poll<bool> {
let mut buf = if let Some(buf) = state.get_write_buf() {
@ -284,12 +278,14 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
let pool = state.memory_pool();
if len != 0 {
// log::trace!("flushing framed transport: {:?}", buf);
//log::trace!("flushing framed transport: {:?}", buf);
let mut written = 0;
while written < len {
match Pin::new(&mut *io).poll_write(cx, &buf[written..]) {
Poll::Pending => break,
Poll::Pending => {
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("Disconnected during flush, written {}", written);
@ -311,7 +307,7 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
}
}
}
// log::trace!("flushed {} bytes", written);
//log::trace!("flushed {} bytes", written);
// remove written data
let result = if written == len {

View file

@ -22,6 +22,6 @@ ntex-util = "0.1.2"
openssl = "0.10.32"
[dev-dependencies]
ntex = { version = "0.4.14", features = ["openssl"] }
ntex = { version = "0.5.0", features = ["openssl"] }
futures = "0.3"
env_logger = "0.9"

View file

@ -10,6 +10,13 @@ use ntex_io::{
use ntex_util::{future::poll_fn, time, time::Millis};
use openssl::ssl::{self, SslStream};
/// Selected alpn protocol
pub enum AlpnHttpProtocol {
Http1,
Http2,
}
/// An implementation of SSL streams
pub struct SslFilter<F> {
inner: RefCell<SslStream<IoInner<F>>>,
}
@ -191,7 +198,7 @@ impl SslAcceptor {
/// Set handshake timeout.
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
self.timeout = timeout.into();
self
}
@ -209,7 +216,7 @@ impl Clone for SslAcceptor {
impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
type Filter = SslFilter<F>;
type Error = io::Error;
type Error = Box<dyn Error>;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, Self::Error>>>>;
fn create(self, st: Io<F>) -> Self::Future {
@ -225,8 +232,7 @@ impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
read_buf: None,
write_buf: None,
};
let ssl_stream =
ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?;
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
Ok(SslFilter {
inner: RefCell::new(ssl_stream),
@ -234,9 +240,9 @@ impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
})?;
poll_fn(|cx| {
let _ = st.write().poll_flush(cx)?;
let _ = st.write().poll_flush(cx, true)?;
handle_result(st.filter().inner.borrow_mut().accept(), &st, cx)
.map_err(map_to_ioerr)
.map_err(Into::<Box<dyn Error>>::into)
})
.await?;
@ -244,7 +250,7 @@ impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
})
.await
.map_err(|_| {
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout")
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
})
.and_then(|item| item)
})
@ -265,7 +271,7 @@ impl SslConnector {
impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
type Filter = SslFilter<F>;
type Error = io::Error;
type Error = Box<dyn Error>;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, Self::Error>>>>;
fn create(self, st: Io<F>) -> Self::Future {
@ -277,8 +283,7 @@ impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
read_buf: None,
write_buf: None,
};
let ssl_stream =
ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?;
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
Ok(SslFilter {
inner: RefCell::new(ssl_stream),
@ -286,9 +291,9 @@ impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
})?;
poll_fn(|cx| {
let _ = st.write().poll_flush(cx)?;
let _ = st.write().poll_flush(cx, true)?;
handle_result(st.filter().inner.borrow_mut().connect(), &st, cx)
.map_err(map_to_ioerr)
.map_err(Into::<Box<dyn Error>>::into)
})
.await?;

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.4.14"
version = "0.5.0"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"

View file

@ -1,148 +0,0 @@
use std::task::{Context, Poll};
use std::{future::Future, pin::Pin};
use crate::io::Io;
use crate::service::{Service, ServiceFactory};
use crate::util::{PoolId, PoolRef, Ready};
use super::service::ConnectServiceResponse;
use super::{Address, Connect, ConnectError, Connector};
pub struct IoConnector<T> {
inner: Connector<T>,
pool: PoolRef,
}
impl<T> IoConnector<T> {
/// Construct new connect service with custom dns resolver
pub fn new() -> Self {
IoConnector {
inner: Connector::new(),
pool: PoolId::P0.pool_ref(),
}
}
/// Set memory pool.
///
/// Use specified memory pool for memory allocations. By default P0
/// memory pool is used.
pub fn memory_pool(mut self, id: PoolId) -> Self {
self.pool = id.pool_ref();
self
}
}
impl<T: Address> IoConnector<T> {
/// Resolve and connect to remote host
pub fn connect<U>(&self, message: U) -> IoConnectServiceResponse<T>
where
Connect<T>: From<U>,
{
IoConnectServiceResponse {
inner: self.inner.call(message.into()),
pool: self.pool,
}
}
}
impl<T> Default for IoConnector<T> {
fn default() -> Self {
IoConnector::new()
}
}
impl<T> Clone for IoConnector<T> {
fn clone(&self) -> Self {
IoConnector {
inner: self.inner.clone(),
pool: self.pool,
}
}
}
impl<T: Address> ServiceFactory for IoConnector<T> {
type Request = Connect<T>;
type Response = Io;
type Error = ConnectError;
type Config = ();
type Service = IoConnector<T>;
type InitError = ();
type Future = Ready<Self::Service, Self::InitError>;
#[inline]
fn new_service(&self, _: ()) -> Self::Future {
Ready::Ok(self.clone())
}
}
impl<T: Address> Service for IoConnector<T> {
type Request = Connect<T>;
type Response = Io;
type Error = ConnectError;
type Future = IoConnectServiceResponse<T>;
#[inline]
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&self, req: Connect<T>) -> Self::Future {
self.connect(req)
}
}
#[doc(hidden)]
pub struct IoConnectServiceResponse<T: Address> {
inner: ConnectServiceResponse<T>,
pool: PoolRef,
}
impl<T: Address> Future for IoConnectServiceResponse<T> {
type Output = Result<Io, ConnectError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.inner).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(stream)) => {
Poll::Ready(Ok(Io::with_memory_pool(stream, self.pool)))
}
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[crate::rt_test]
async fn test_connect() {
let server = crate::server::test_server(|| {
crate::service::fn_service(|_| async { Ok::<_, ()>(()) })
});
let srv = IoConnector::default();
let result = srv.connect("").await;
assert!(result.is_err());
let result = srv.connect("localhost:99999").await;
assert!(result.is_err());
let srv = IoConnector::default();
let result = srv.connect(format!("{}", server.addr())).await;
assert!(result.is_ok());
let msg = Connect::new(format!("{}", server.addr())).set_addrs(vec![
format!("127.0.0.1:{}", server.addr().port() - 1)
.parse()
.unwrap(),
server.addr(),
]);
let result = crate::connect::connect(msg).await;
assert!(result.is_ok());
let msg = Connect::new(server.addr());
let result = crate::connect::connect(msg).await;
assert!(result.is_ok());
}
}

View file

@ -2,7 +2,6 @@
use std::future::Future;
mod error;
mod io;
mod message;
mod resolve;
mod service;
@ -13,19 +12,18 @@ mod uri;
#[cfg(feature = "openssl")]
pub mod openssl;
#[cfg(feature = "rustls")]
pub mod rustls;
use crate::rt::net::TcpStream;
//#[cfg(feature = "rustls")]
//pub mod rustls;
pub use self::error::ConnectError;
pub use self::io::IoConnector;
pub use self::message::{Address, Connect};
pub use self::resolve::Resolver;
pub use self::service::Connector;
use crate::io::Io;
/// Resolve and connect to remote host
pub fn connect<T, U>(message: U) -> impl Future<Output = Result<TcpStream, ConnectError>>
pub fn connect<T, U>(message: U) -> impl Future<Output = Result<Io, ConnectError>>
where
T: Address + 'static,
Connect<T>: From<U>,

View file

@ -2,14 +2,12 @@ use std::{future::Future, io, pin::Pin, task::Context, task::Poll};
use ntex_openssl::{SslConnector as IoSslConnector, SslFilter};
pub use open_ssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
pub use tokio_openssl::SslStream;
use crate::io::{DefaultFilter, Io};
use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory};
use crate::util::Ready;
use super::{Address, Connect, ConnectError, Connector, IoConnector as BaseIoConnector};
use super::{Address, Connect, ConnectError, Connector};
pub struct OpensslConnector<T> {
connector: Connector<T>,
@ -27,103 +25,6 @@ impl<T> OpensslConnector<T> {
}
impl<T: Address + 'static> OpensslConnector<T> {
/// Resolve and connect to remote host
pub fn connect<U>(
&self,
message: U,
) -> impl Future<Output = Result<SslStream<TcpStream>, ConnectError>>
where
Connect<T>: From<U>,
{
let message = Connect::from(message);
let host = message.host().to_string();
let conn = self.connector.call(message);
let openssl = self.openssl.clone();
async move {
let io = conn.await?;
trace!("SSL Handshake start for: {:?}", host);
match openssl.configure() {
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
Ok(config) => {
let config = config
.into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let mut io = SslStream::new(config, io)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
match Pin::new(&mut io).connect().await {
Ok(_) => {
trace!("SSL Handshake success: {:?}", host);
Ok(io)
}
Err(e) => {
trace!("SSL Handshake error: {:?}", e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))
.into())
}
}
}
}
}
}
}
impl<T> Clone for OpensslConnector<T> {
fn clone(&self) -> Self {
OpensslConnector {
connector: self.connector.clone(),
openssl: self.openssl.clone(),
}
}
}
impl<T: Address + 'static> ServiceFactory for OpensslConnector<T> {
type Request = Connect<T>;
type Response = SslStream<TcpStream>;
type Error = ConnectError;
type Config = ();
type Service = OpensslConnector<T>;
type InitError = ();
type Future = Ready<Self::Service, Self::InitError>;
fn new_service(&self, _: ()) -> Self::Future {
Ready::Ok(self.clone())
}
}
impl<T: Address + 'static> Service for OpensslConnector<T> {
type Request = Connect<T>;
type Response = SslStream<TcpStream>;
type Error = ConnectError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
#[inline]
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&self, req: Connect<T>) -> Self::Future {
Box::pin(self.connect(req))
}
}
pub struct IoConnector<T> {
connector: BaseIoConnector<T>,
openssl: SslConnector,
}
impl<T> IoConnector<T> {
/// Construct new OpensslConnectService factory
pub fn new(connector: SslConnector) -> Self {
IoConnector {
connector: BaseIoConnector::default(),
openssl: connector,
}
}
}
impl<T: Address + 'static> IoConnector<T> {
/// Resolve and connect to remote host
pub fn connect<U>(
&self,
@ -164,21 +65,21 @@ impl<T: Address + 'static> IoConnector<T> {
}
}
impl<T> Clone for IoConnector<T> {
impl<T> Clone for OpensslConnector<T> {
fn clone(&self) -> Self {
IoConnector {
OpensslConnector {
connector: self.connector.clone(),
openssl: self.openssl.clone(),
}
}
}
impl<T: Address + 'static> ServiceFactory for IoConnector<T> {
impl<T: Address + 'static> ServiceFactory for OpensslConnector<T> {
type Request = Connect<T>;
type Response = Io<SslFilter<DefaultFilter>>;
type Error = ConnectError;
type Config = ();
type Service = IoConnector<T>;
type Service = OpensslConnector<T>;
type InitError = ();
type Future = Ready<Self::Service, Self::InitError>;
@ -187,7 +88,7 @@ impl<T: Address + 'static> ServiceFactory for IoConnector<T> {
}
}
impl<T: Address + 'static> Service for IoConnector<T> {
impl<T: Address + 'static> Service for OpensslConnector<T> {
type Request = Connect<T>;
type Response = Io<SslFilter<DefaultFilter>>;
type Error = ConnectError;

View file

@ -1,14 +1,16 @@
use std::task::{Context, Poll};
use std::{collections::VecDeque, future::Future, io, net::SocketAddr, pin::Pin};
use crate::io::Io;
use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory};
use crate::util::{Either, Ready};
use crate::util::{Either, PoolId, PoolRef, Ready};
use super::{Address, Connect, ConnectError, Resolver};
pub struct Connector<T> {
resolver: Resolver<T>,
pool: PoolRef,
}
impl<T> Connector<T> {
@ -16,8 +18,18 @@ impl<T> Connector<T> {
pub fn new() -> Self {
Connector {
resolver: Resolver::new(),
pool: PoolId::P0.pool_ref(),
}
}
/// Set memory pool.
///
/// Use specified memory pool for memory allocations. By default P0
/// memory pool is used.
pub fn memory_pool(mut self, id: PoolId) -> Self {
self.pool = id.pool_ref();
self
}
}
impl<T: Address> Connector<T> {
@ -25,11 +37,14 @@ impl<T: Address> Connector<T> {
pub fn connect<U>(
&self,
message: U,
) -> impl Future<Output = Result<TcpStream, ConnectError>>
) -> impl Future<Output = Result<Io, ConnectError>>
where
Connect<T>: From<U>,
{
ConnectServiceResponse::new(self.resolver.call(message.into()))
ConnectServiceResponse {
state: ConnectState::Resolve(self.resolver.call(message.into())),
pool: self.pool,
}
}
}
@ -43,13 +58,14 @@ impl<T> Clone for Connector<T> {
fn clone(&self) -> Self {
Connector {
resolver: self.resolver.clone(),
pool: self.pool,
}
}
}
impl<T: Address> ServiceFactory for Connector<T> {
type Request = Connect<T>;
type Response = TcpStream;
type Response = Io;
type Error = ConnectError;
type Config = ();
type Service = Connector<T>;
@ -64,7 +80,7 @@ impl<T: Address> ServiceFactory for Connector<T> {
impl<T: Address> Service for Connector<T> {
type Request = Connect<T>;
type Response = TcpStream;
type Response = Io;
type Error = ConnectError;
type Future = ConnectServiceResponse<T>;
@ -87,18 +103,20 @@ enum ConnectState<T: Address> {
#[doc(hidden)]
pub struct ConnectServiceResponse<T: Address> {
state: ConnectState<T>,
pool: PoolRef,
}
impl<T: Address> ConnectServiceResponse<T> {
pub(super) fn new(fut: <Resolver<T> as Service>::Future) -> Self {
ConnectServiceResponse {
Self {
state: ConnectState::Resolve(fut),
pool: PoolId::P0.pool_ref(),
}
}
}
impl<T: Address> Future for ConnectServiceResponse<T> {
type Output = Result<TcpStream, ConnectError>;
type Output = Result<Io, ConnectError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state {
@ -126,7 +144,12 @@ impl<T: Address> Future for ConnectServiceResponse<T> {
}
}
},
ConnectState::Connect(ref mut fut) => Pin::new(fut).poll(cx),
ConnectState::Connect(ref mut fut) => match Pin::new(fut).poll(cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(stream) => {
Poll::Ready(Ok(Io::with_memory_pool(stream, self.pool)))
}
},
}
}
}

View file

@ -1,901 +0,0 @@
//! Framed transport dispatcher
use std::{
cell::Cell, cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context,
task::Poll, time, time::Instant,
};
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
use crate::framed::{DispatchItem, Read, ReadTask, State, Timer, Write, WriteTask};
use crate::service::{IntoService, Service};
use crate::time::Seconds;
use crate::util::{Either, Pool};
type Response<U> = <U as Encoder>::Item;
pin_project_lite::pin_project! {
/// Framed dispatcher - is a future that reads frames from Framed object
/// and pass then to the service.
pub struct Dispatcher<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Encoder,
U: Decoder,
<U as Encoder>::Item: 'static,
{
service: S,
inner: DispatcherInner<S, U>,
#[pin]
fut: Option<S::Future>,
}
}
struct DispatcherInner<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Encoder + Decoder,
{
st: Cell<DispatcherState>,
state: State,
timer: Timer,
ka_timeout: Seconds,
ka_updated: Cell<Instant>,
error: Cell<Option<S::Error>>,
shared: Rc<DispatcherShared<S, U>>,
pool: Pool,
}
struct DispatcherShared<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
U: Encoder + Decoder,
{
codec: U,
error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
inflight: Cell<usize>,
}
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
Processing,
Backpressure,
Stop,
Shutdown,
}
enum DispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
}
enum PollService<U: Encoder + Decoder> {
Item(DispatchItem<U>),
ServiceError,
Ready,
}
impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
fn from(err: Either<S, U>) -> Self {
match err {
Either::Left(err) => DispatcherError::Service(err),
Either::Right(err) => DispatcherError::Encoder(err),
}
}
}
impl<S, U> Dispatcher<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + 'static,
<U as Encoder>::Item: 'static,
{
/// Construct new `Dispatcher` instance.
pub fn new<T, F: IntoService<S>>(
io: T,
codec: U,
state: State,
service: F,
timer: Timer,
) -> Self
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
let io = Rc::new(RefCell::new(io));
// start support tasks
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(WriteTask::new(io, state.clone()));
Self::from_state(codec, state, service, timer)
}
/// Construct new `Dispatcher` instance.
pub fn from_state<F: IntoService<S>>(
codec: U,
state: State,
service: F,
timer: Timer,
) -> Self {
let updated = timer.now();
let ka_timeout = Seconds(30);
// register keepalive timer
let expire = updated + time::Duration::from(ka_timeout);
timer.register(expire, expire, &state);
Dispatcher {
service: service.into_service(),
fut: None,
inner: DispatcherInner {
pool: state.memory_pool().pool(),
ka_updated: Cell::new(updated),
error: Cell::new(None),
st: Cell::new(DispatcherState::Processing),
shared: Rc::new(DispatcherShared {
codec,
error: Cell::new(None),
inflight: Cell::new(0),
}),
state,
timer,
ka_timeout,
},
}
}
/// Set keep-alive timeout.
///
/// To disable timeout set value to 0.
///
/// By default keep-alive timeout is set to 30 seconds.
pub fn keepalive_timeout(mut self, timeout: Seconds) -> Self {
// register keepalive timer
let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka());
if timeout.is_zero() {
self.inner.timer.unregister(prev, &self.inner.state);
} else {
let expire = self.inner.ka_updated.get() + time::Duration::from(timeout);
self.inner.timer.register(expire, prev, &self.inner.state);
}
self.inner.ka_timeout = timeout;
self
}
/// Set connection disconnect timeout in seconds.
///
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
/// within this time, the connection get dropped.
///
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 1 seconds.
pub fn disconnect_timeout(self, val: Seconds) -> Self {
self.inner.state.set_disconnect_timeout(val);
self
}
}
impl<S, U> DispatcherShared<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Encoder + Decoder,
<U as Encoder>::Item: 'static,
{
fn handle_result(&self, item: Result<S::Response, S::Error>, write: Write<'_>) {
self.inflight.set(self.inflight.get() - 1);
match write.encode_result(item, &self.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(err) => self.error.set(Some(err.into())),
}
write.wake_dispatcher();
}
}
impl<S, U> Future for Dispatcher<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + 'static,
<U as Encoder>::Item: 'static,
{
type Output = Result<(), S::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
let slf = &this.inner;
let state = &slf.state;
let read = state.read();
let write = state.write();
// handle service response future
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
match fut.poll(cx) {
Poll::Pending => (),
Poll::Ready(item) => {
this.fut.set(None);
slf.shared.inflight.set(slf.shared.inflight.get() - 1);
slf.handle_result(item, write);
}
}
}
// handle memory pool pressure
if slf.pool.poll_ready(cx).is_pending() {
read.pause(cx.waker());
return Poll::Pending;
}
loop {
match slf.st.get() {
DispatcherState::Processing => {
let result = match slf.poll_service(this.service, cx, read) {
Poll::Pending => return Poll::Pending,
Poll::Ready(result) => result,
};
let item = match result {
PollService::Ready => {
if !write.is_ready() {
// instruct write task to notify dispatcher when data is flushed
write.enable_backpressure(Some(cx.waker()));
slf.st.set(DispatcherState::Backpressure);
DispatchItem::WBackPressureEnabled
} else if read.is_ready() {
// decode incoming bytes if buffer is ready
match read.decode(&slf.shared.codec) {
Ok(Some(el)) => {
slf.update_keepalive();
DispatchItem::Item(el)
}
Ok(None) => {
log::trace!("not enough data to decode next frame, register dispatch task");
read.wake(cx.waker());
return Poll::Pending;
}
Err(err) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::DecoderError(err)
}
}
} else {
// no new events
state.register_dispatcher(cx.waker());
return Poll::Pending;
}
}
PollService::Item(item) => item,
PollService::ServiceError => continue,
};
// call service
if this.fut.is_none() {
// optimize first service call
this.fut.set(Some(this.service.call(item)));
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
Poll::Ready(res) => {
this.fut.set(None);
slf.handle_result(res, write);
}
Poll::Pending => {
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
}
}
} else {
slf.spawn_service_call(this.service.call(item));
}
}
// handle write back-pressure
DispatcherState::Backpressure => {
let result = match slf.poll_service(this.service, cx, read) {
Poll::Ready(result) => result,
Poll::Pending => return Poll::Pending,
};
let item = match result {
PollService::Ready => {
if write.is_ready() {
slf.st.set(DispatcherState::Processing);
DispatchItem::WBackPressureDisabled
} else {
return Poll::Pending;
}
}
PollService::Item(item) => item,
PollService::ServiceError => continue,
};
// call service
if this.fut.is_none() {
// optimize first service call
this.fut.set(Some(this.service.call(item)));
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
Poll::Ready(res) => {
this.fut.set(None);
slf.handle_result(res, write);
}
Poll::Pending => {
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
}
}
} else {
slf.spawn_service_call(this.service.call(item));
}
}
// drain service responses
DispatcherState::Stop => {
// service may relay on poll_ready for response results
if !this.inner.state.is_dispatcher_ready_err() {
let _ = this.service.poll_ready(cx);
}
if slf.shared.inflight.get() == 0 {
slf.st.set(DispatcherState::Shutdown);
state.shutdown_io();
} else {
state.register_dispatcher(cx.waker());
return Poll::Pending;
}
}
// shutdown service
DispatcherState::Shutdown => {
let err = slf.error.take();
return if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
log::trace!("service shutdown is completed, stop");
Poll::Ready(if let Some(err) = err {
Err(err)
} else {
Ok(())
})
} else {
slf.error.set(err);
Poll::Pending
};
}
}
}
}
}
impl<S, U> DispatcherInner<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + 'static,
{
/// spawn service call
fn spawn_service_call(&self, fut: S::Future) {
self.shared.inflight.set(self.shared.inflight.get() + 1);
let st = self.state.clone();
let shared = self.shared.clone();
crate::rt::spawn(async move {
let item = fut.await;
shared.handle_result(item, st.write());
});
}
fn handle_result(
&self,
item: Result<Option<<U as Encoder>::Item>, S::Error>,
write: Write<'_>,
) {
match write.encode_result(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.error.set(Some(err));
}
Err(Either::Right(err)) => {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
}
}
fn poll_service(
&self,
srv: &S,
cx: &mut Context<'_>,
read: Read<'_>,
) -> Poll<PollService<U>> {
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
// service is ready, wake io read task
read.resume();
// check keepalive timeout
self.check_keepalive();
// check for errors
Poll::Ready(if let Some(err) = self.shared.error.take() {
log::trace!("error occured, stopping dispatcher");
self.unregister_keepalive();
self.st.set(DispatcherState::Stop);
match err {
DispatcherError::KeepAlive => {
PollService::Item(DispatchItem::KeepAliveTimeout)
}
DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
DispatcherError::Service(err) => {
self.error.set(Some(err));
PollService::ServiceError
}
}
} else if self.state.is_dispatcher_stopped() {
log::trace!("dispatcher is instructed to stop");
self.unregister_keepalive();
// process unhandled data
if let Ok(Some(el)) = read.decode(&self.shared.codec) {
PollService::Item(DispatchItem::Item(el))
} else {
self.st.set(DispatcherState::Stop);
// get io error
if let Some(err) = self.state.take_io_error() {
PollService::Item(DispatchItem::IoError(err))
} else {
PollService::ServiceError
}
}
} else {
PollService::Ready
})
}
// pause io read task
Poll::Pending => {
log::trace!("service is not ready, register dispatch task");
read.pause(cx.waker());
Poll::Pending
}
// handle service readiness error
Poll::Ready(Err(err)) => {
log::trace!("service readiness check failed, stopping");
self.st.set(DispatcherState::Stop);
self.error.set(Some(err));
self.unregister_keepalive();
self.state.dispatcher_ready_err();
Poll::Ready(PollService::ServiceError)
}
}
}
fn ka(&self) -> Seconds {
self.ka_timeout
}
fn ka_enabled(&self) -> bool {
self.ka_timeout.non_zero()
}
/// check keepalive timeout
fn check_keepalive(&self) {
if self.state.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
} else {
self.shared.error.set(Some(DispatcherError::KeepAlive));
}
}
}
/// update keep-alive timer
fn update_keepalive(&self) {
if self.ka_enabled() {
let updated = self.timer.now();
if updated != self.ka_updated.get() {
let ka = time::Duration::from(self.ka());
self.timer.register(
updated + ka,
self.ka_updated.get() + ka,
&self.state,
);
self.ka_updated.set(updated);
}
}
}
/// unregister keep-alive timer
fn unregister_keepalive(&self) {
if self.ka_enabled() {
self.timer.unregister(
self.ka_updated.get() + time::Duration::from(self.ka()),
&self.state,
);
}
}
}
#[cfg(test)]
mod tests {
use rand::Rng;
use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex};
use std::time::Duration;
use crate::codec::BytesCodec;
use crate::testing::Io;
use crate::time::{sleep, Millis};
use crate::util::{Bytes, PoolRef, Ready};
use super::*;
impl<S, U> Dispatcher<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Decoder + Encoder + 'static,
<U as Encoder>::Item: 'static,
{
/// Construct new `Dispatcher` instance
pub(crate) fn debug<T, F: IntoService<S>>(
io: T,
codec: U,
service: F,
) -> (Self, State)
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
let timer = Timer::default();
let ka_timeout = Seconds(1);
let ka_updated = timer.now();
let state = State::new();
let io = Rc::new(RefCell::new(io));
let shared = Rc::new(DispatcherShared {
codec: codec,
error: Cell::new(None),
inflight: Cell::new(0),
});
let expire = ka_updated + Duration::from_millis(500);
timer.register(expire, expire, &state);
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(WriteTask::new(io.clone(), state.clone()));
(
Dispatcher {
service: service.into_service(),
fut: None,
inner: DispatcherInner {
shared,
timer,
ka_timeout,
ka_updated: Cell::new(ka_updated),
state: state.clone(),
error: Cell::new(None),
st: Cell::new(DispatcherState::Processing),
pool: state.memory_pool().pool(),
},
},
state,
)
}
}
#[crate::rt_test]
async fn test_basic() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, _) = Dispatcher::debug(
server,
BytesCodec,
crate::service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
sleep(Millis(50)).await;
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}),
);
crate::rt::spawn(async move {
let _ = disp.await;
});
sleep(Millis(25)).await;
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
client.write("GET /test HTTP/1\r\n\r\n");
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
client.close().await;
assert!(client.is_server_dropped());
}
#[crate::rt_test]
async fn test_sink() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, st) = Dispatcher::debug(
server,
BytesCodec,
crate::service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}),
);
crate::rt::spawn(async move {
let _ = disp.disconnect_timeout(Seconds(1)).await;
});
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(st
.write()
.encode(Bytes::from_static(b"test"), &mut BytesCodec)
.is_ok());
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
st.close();
sleep(Millis(1100)).await;
assert!(client.is_server_dropped());
}
#[crate::rt_test]
async fn test_err_in_service() {
let (client, server) = Io::create();
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, state) = Dispatcher::debug(
server,
BytesCodec,
crate::service::fn_service(|_: DispatchItem<BytesCodec>| async move {
Err::<Option<Bytes>, _>(())
}),
);
state
.write()
.encode(
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
&mut BytesCodec,
)
.unwrap();
crate::rt::spawn(async move {
let _ = disp.await;
});
// buffer should be flushed
client.remote_buffer_cap(1024);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
// write side must be closed, dispatcher waiting for read side to close
assert!(client.is_closed());
// close read side
client.close().await;
assert!(client.is_server_dropped());
}
#[crate::rt_test]
async fn test_err_in_service_ready() {
let (client, server) = Io::create();
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let counter = Rc::new(Cell::new(0));
struct Srv(Rc<Cell<usize>>);
impl Service for Srv {
type Request = DispatchItem<BytesCodec>;
type Response = Option<Response<BytesCodec>>;
type Error = ();
type Future = Ready<Option<Response<BytesCodec>>, ()>;
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.0.set(self.0.get() + 1);
Poll::Ready(Err(()))
}
fn call(&self, _: DispatchItem<BytesCodec>) -> Self::Future {
Ready::Ok(None)
}
}
let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
state
.write()
.encode(
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
&mut BytesCodec,
)
.unwrap();
crate::rt::spawn(async move {
let _ = disp.await;
});
// buffer should be flushed
client.remote_buffer_cap(1024);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
// write side must be closed, dispatcher waiting for read side to close
assert!(client.is_closed());
// close read side
client.close().await;
assert!(client.is_server_dropped());
// service must be checked for readiness only once
assert_eq!(counter.get(), 1);
}
#[crate::rt_test]
async fn test_write_backpressure() {
let (client, server) = Io::create();
// do not allow to write to socket
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
let data2 = data.clone();
let (disp, state) = Dispatcher::debug(
server,
BytesCodec,
crate::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
let data = data2.clone();
async move {
match msg {
DispatchItem::Item(_) => {
data.lock().unwrap().borrow_mut().push(0);
let bytes = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(65_536)
.map(char::from)
.collect::<String>();
return Ok::<_, ()>(Some(Bytes::from(bytes)));
}
DispatchItem::WBackPressureEnabled => {
data.lock().unwrap().borrow_mut().push(1);
}
DispatchItem::WBackPressureDisabled => {
data.lock().unwrap().borrow_mut().push(2);
}
_ => (),
}
Ok(None)
}
}),
);
let pool = PoolRef::default();
pool.set_read_params(8 * 1024, 1024);
pool.set_write_params(16 * 1024, 1024);
crate::rt::spawn(async move {
let _ = disp.await;
});
let buf = client.read_any();
assert_eq!(buf, Bytes::from_static(b""));
client.write("GET /test HTTP/1\r\n\r\n");
sleep(Millis(25)).await;
// buf must be consumed
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
// response message
assert!(!state.write().is_ready());
assert_eq!(state.write().with_buf(|buf| buf.len()), 65536);
client.remote_buffer_cap(10240);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()), 55296);
client.remote_buffer_cap(45056);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()), 10240);
// backpressure disabled
assert!(state.write().is_ready());
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
}
#[crate::rt_test]
async fn test_keepalive() {
let (client, server) = Io::create();
// do not allow to write to socket
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
let data2 = data.clone();
let (disp, state) = Dispatcher::debug(
server,
BytesCodec,
crate::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
let data = data2.clone();
async move {
match msg {
DispatchItem::Item(bytes) => {
data.lock().unwrap().borrow_mut().push(0);
return Ok::<_, ()>(Some(bytes.freeze()));
}
DispatchItem::KeepAliveTimeout => {
data.lock().unwrap().borrow_mut().push(1);
}
_ => (),
}
Ok(None)
}
}),
);
crate::rt::spawn(async move {
let _ = disp
.keepalive_timeout(Seconds::ZERO)
.keepalive_timeout(Seconds(1))
.await;
});
state.set_disconnect_timeout(Seconds(1));
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
sleep(Millis(3500)).await;
// write side must be closed, dispatcher should fail with keep-alive
let flags = state.flags();
assert!(state.is_io_err());
assert!(state.is_io_shutdown());
assert!(flags.contains(crate::framed::state::Flags::IO_SHUTDOWN));
assert!(client.is_closed());
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
}
#[crate::rt_test]
async fn test_unhandled_data() {
let handled = Arc::new(AtomicBool::new(false));
let handled2 = handled.clone();
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, _) = Dispatcher::debug(
server,
BytesCodec,
crate::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
handled2.store(true, Relaxed);
async move {
sleep(Millis(50)).await;
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}
}),
);
client.close().await;
crate::rt::spawn(async move {
let _ = disp.await;
});
sleep(Millis(50)).await;
assert!(handled.load(Relaxed));
}
}

View file

@ -1,89 +0,0 @@
use std::{fmt, io};
mod dispatcher;
mod read;
mod state;
mod time;
mod write;
pub use self::dispatcher::Dispatcher;
pub use self::read::ReadTask;
pub use self::state::{OnDisconnect, Read, State, Write};
pub use self::time::Timer;
pub use self::write::WriteTask;
use crate::codec::{Decoder, Encoder};
/// Framed transport item
pub enum DispatchItem<U: Encoder + Decoder> {
Item(<U as Decoder>::Item),
/// Write back-pressure enabled
WBackPressureEnabled,
/// Write back-pressure disabled
WBackPressureDisabled,
/// Keep alive timeout
KeepAliveTimeout,
/// Decoder parse error
DecoderError(<U as Decoder>::Error),
/// Encoder parse error
EncoderError(<U as Encoder>::Error),
/// Unexpected io error
IoError(io::Error),
}
impl<U> fmt::Debug for DispatchItem<U>
where
U: Encoder + Decoder,
<U as Decoder>::Item: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
DispatchItem::Item(ref item) => {
write!(fmt, "DispatchItem::Item({:?})", item)
}
DispatchItem::WBackPressureEnabled => {
write!(fmt, "DispatchItem::WBackPressureEnabled")
}
DispatchItem::WBackPressureDisabled => {
write!(fmt, "DispatchItem::WBackPressureDisabled")
}
DispatchItem::KeepAliveTimeout => {
write!(fmt, "DispatchItem::KeepAliveTimeout")
}
DispatchItem::EncoderError(ref e) => {
write!(fmt, "DispatchItem::EncoderError({:?})", e)
}
DispatchItem::DecoderError(ref e) => {
write!(fmt, "DispatchItem::DecoderError({:?})", e)
}
DispatchItem::IoError(ref e) => {
write!(fmt, "DispatchItem::IoError({:?})", e)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::BytesCodec;
#[test]
fn test_fmt() {
type T = DispatchItem<BytesCodec>;
let err = T::EncoderError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::Encoder"));
let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::Decoder"));
let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatchItem::IoError"));
assert!(format!("{:?}", T::WBackPressureEnabled)
.contains("DispatchItem::WBackPressureEnabled"));
assert!(format!("{:?}", T::WBackPressureDisabled)
.contains("DispatchItem::WBackPressureDisabled"));
assert!(format!("{:?}", T::KeepAliveTimeout)
.contains("DispatchItem::KeepAliveTimeout"));
}
}

View file

@ -1,50 +0,0 @@
use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::State;
/// Read io task
pub struct ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
io: Rc<RefCell<T>>,
state: State,
}
impl<T> ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
/// Create new read io task
pub fn new(io: Rc<RefCell<T>>, state: State) -> Self {
Self { io, state }
}
}
impl<T> Future for ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.state.is_io_shutdown() {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
} else if self.state.is_io_stop() {
self.state.wake_dispatcher();
Poll::Ready(())
} else if self.state.is_read_paused() {
self.state.register_read_task(cx.waker());
Poll::Pending
} else {
let mut io = self.io.borrow_mut();
if self.state.read_io(&mut *io, cx) {
Poll::Pending
} else {
Poll::Ready(())
}
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,115 +0,0 @@
use std::{cell::RefCell, collections::BTreeMap, rc::Rc, time::Instant};
use crate::framed::State;
use crate::time::{now, sleep, Millis};
use crate::util::HashSet;
pub struct Timer(Rc<RefCell<Inner>>);
struct Inner {
resolution: Millis,
current: Option<Instant>,
notifications: BTreeMap<Instant, HashSet<State>>,
}
impl Inner {
fn new(resolution: Millis) -> Self {
Inner {
resolution,
current: None,
notifications: BTreeMap::default(),
}
}
fn unregister(&mut self, expire: Instant, state: &State) {
if let Some(ref mut states) = self.notifications.get_mut(&expire) {
states.remove(state);
if states.is_empty() {
self.notifications.remove(&expire);
}
}
}
}
impl Clone for Timer {
fn clone(&self) -> Self {
Timer(self.0.clone())
}
}
impl Default for Timer {
fn default() -> Self {
Timer::new(Millis::ONE_SEC)
}
}
impl Timer {
/// Create new timer with resolution in milliseconds
pub fn new(resolution: Millis) -> Timer {
Timer(Rc::new(RefCell::new(Inner::new(resolution))))
}
pub fn register(&self, expire: Instant, previous: Instant, state: &State) {
{
let mut inner = self.0.borrow_mut();
inner.unregister(previous, state);
inner
.notifications
.entry(expire)
.or_insert_with(HashSet::default)
.insert(state.clone());
}
let _ = self.now();
}
pub fn unregister(&self, expire: Instant, state: &State) {
self.0.borrow_mut().unregister(expire, state);
}
/// Get current time. This function has to be called from
/// future's poll method, otherwise it panics.
pub fn now(&self) -> Instant {
let cur = self.0.borrow().current;
if let Some(cur) = cur {
cur
} else {
let now_val = now();
let inner = self.0.clone();
let interval = {
let mut b = inner.borrow_mut();
b.current = Some(now_val);
b.resolution
};
crate::rt::spawn(async move {
sleep(interval).await;
let empty = {
let mut i = inner.borrow_mut();
let now = i.current.take().unwrap_or_else(now);
// notify io dispatcher
while let Some(key) = i.notifications.keys().next() {
let key = *key;
if key <= now {
for st in i.notifications.remove(&key).unwrap() {
st.keepalive_timeout();
}
} else {
break;
}
}
i.notifications.is_empty()
};
// extra tick
if !empty {
let _ = Timer(inner).now();
}
});
now_val
}
}
}

View file

@ -1,166 +0,0 @@
use std::task::{Context, Poll};
use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc};
use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
use crate::framed::State;
use crate::time::{sleep, Sleep};
#[derive(Debug)]
enum IoWriteState {
Processing,
Shutdown(Option<Sleep>, Shutdown),
}
#[derive(Debug)]
enum Shutdown {
None,
Flushed,
Stopping,
}
/// Write io task
pub struct WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
st: IoWriteState,
io: Rc<RefCell<T>>,
state: State,
}
impl<T> WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
/// Create new write io task
pub fn new(io: Rc<RefCell<T>>, state: State) -> Self {
Self {
io,
state,
st: IoWriteState::Processing,
}
}
/// Shutdown io stream
pub fn shutdown(io: Rc<RefCell<T>>, state: State) -> Self {
let disconnect_timeout = state.get_disconnect_timeout();
let st = IoWriteState::Shutdown(disconnect_timeout.map(sleep), Shutdown::None);
Self { st, io, state }
}
}
impl<T> Future for WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().get_mut();
// IO error occured
if this.state.is_io_err() {
log::trace!("write io is closed");
return Poll::Ready(());
} else if this.state.is_io_stop() {
self.state.wake_dispatcher();
return Poll::Ready(());
}
match this.st {
IoWriteState::Processing => {
if this.state.is_io_shutdown() {
log::trace!("write task is instructed to shutdown");
let disconnect_timeout = this.state.get_disconnect_timeout();
this.st = IoWriteState::Shutdown(
disconnect_timeout.map(sleep),
Shutdown::None,
);
return self.poll(cx);
}
// flush framed instance
match this.state.flush_io(&mut *this.io.borrow_mut(), cx) {
Poll::Pending | Poll::Ready(true) => Poll::Pending,
Poll::Ready(false) => Poll::Ready(()),
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::None => {
// flush write buffer
let result =
this.state.flush_io(&mut *this.io.borrow_mut(), cx);
match result {
Poll::Ready(true) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(false) => {
this.state.set_wr_shutdown_complete();
log::trace!(
"write task is closed with err during flush"
);
return Poll::Ready(());
}
_ => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx)
{
Poll::Ready(Ok(_)) => {
*st = Shutdown::Stopping;
continue;
}
Poll::Ready(Err(_)) => {
this.state.set_wr_shutdown_complete();
log::trace!(
"write task is closed with err during shutdown"
);
return Poll::Ready(());
}
_ => (),
}
}
Shutdown::Stopping => {
// read until 0 or err
let mut buf = [0u8; 512];
let mut io = this.io.borrow_mut();
loop {
let mut read_buf = ReadBuf::new(&mut buf);
match Pin::new(&mut *io).poll_read(cx, &mut read_buf) {
Poll::Ready(Err(_)) | Poll::Ready(Ok(_))
if read_buf.filled().is_empty() =>
{
this.state.set_wr_shutdown_complete();
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Pending => break,
_ => (),
}
}
}
}
// disconnect timeout
if let Some(ref delay) = delay {
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
}
this.state.set_wr_shutdown_complete();
log::trace!("write task is stopped after delay");
return Poll::Ready(());
}
}
}
}
}

View file

@ -1,11 +1,11 @@
use std::{cell::RefCell, error::Error, fmt, marker::PhantomData, rc::Rc};
use crate::framed::State;
use crate::http::body::MessageBody;
use crate::http::config::{KeepAlive, OnRequest, ServiceConfig};
use crate::http::error::ResponseError;
use crate::http::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler};
use crate::http::h2::H2Service;
use crate::io::{Filter, Io, IoRef};
// use crate::http::h2::H2Service;
use crate::http::helpers::{Data, DataFactory};
use crate::http::request::Request;
use crate::http::response::Response;
@ -18,7 +18,7 @@ use crate::util::PoolId;
///
/// This type can be used to construct an instance of `http service` through a
/// builder-like pattern.
pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
pub struct HttpServiceBuilder<F, S, X = ExpectHandler, U = UpgradeHandler<F>> {
keep_alive: KeepAlive,
client_timeout: Millis,
client_disconnect: Seconds,
@ -26,12 +26,11 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
pool: PoolId,
expect: X,
upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_request: Option<OnRequest<T>>,
_t: PhantomData<(T, S)>,
on_request: Option<OnRequest>,
_t: PhantomData<(F, S)>,
}
impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> {
impl<F, S> HttpServiceBuilder<F, S, ExpectHandler, UpgradeHandler<F>> {
/// Create instance of `ServiceConfigBuilder`
pub fn new() -> Self {
HttpServiceBuilder {
@ -42,15 +41,15 @@ impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> {
pool: PoolId::P1,
expect: ExpectHandler,
upgrade: None,
on_connect: None,
on_request: None,
_t: PhantomData,
}
}
}
impl<T, S, X, U> HttpServiceBuilder<T, S, X, U>
impl<F, S, X, U> HttpServiceBuilder<F, S, X, U>
where
F: Filter + 'static,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
@ -61,7 +60,7 @@ where
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<Config = (), Request = (Request, T, State, Codec), Response = ()>,
U: ServiceFactory<Config = (), Request = (Request, Io<F>, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
@ -122,29 +121,14 @@ where
self
}
#[doc(hidden)]
#[deprecated(since = "0.4.12", note = "Use memory pool config")]
#[inline]
/// Set read/write buffer params
///
/// By default read buffer is 8kb, write buffer is 8kb
pub fn buffer_params(
self,
_max_read_buf_size: u16,
_max_write_buf_size: u16,
_min_buf_size: u16,
) -> Self {
self
}
/// Provide service for `EXPECT: 100-Continue` support.
///
/// Service get called with request that contains `EXPECT` header.
/// Service must return request in case of success, in that case
/// request will be forwarded to main service.
pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U>
pub fn expect<XF, X1>(self, expect: XF) -> HttpServiceBuilder<F, S, X1, U>
where
F: IntoServiceFactory<X1>,
XF: IntoServiceFactory<X1>,
X1: ServiceFactory<Config = (), Request = Request, Response = Request>,
X1::Error: ResponseError + 'static,
X1::InitError: fmt::Debug,
@ -159,7 +143,6 @@ where
pool: self.pool,
expect: expect.into_factory(),
upgrade: self.upgrade,
on_connect: self.on_connect,
on_request: self.on_request,
_t: PhantomData,
}
@ -169,12 +152,12 @@ where
///
/// If service is provided then normal requests handling get halted
/// and this service get called with original request and framed object.
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
pub fn upgrade<UF, U1>(self, upgrade: UF) -> HttpServiceBuilder<F, S, X, U1>
where
F: IntoServiceFactory<U1>,
UF: IntoServiceFactory<U1>,
U1: ServiceFactory<
Config = (),
Request = (Request, T, State, Codec),
Request = (Request, Io<F>, Codec),
Response = (),
>,
U1::Error: fmt::Display + Error + 'static,
@ -190,46 +173,29 @@ where
pool: self.pool,
expect: self.expect,
upgrade: Some(upgrade.into_factory()),
on_connect: self.on_connect,
on_request: self.on_request,
_t: PhantomData,
}
}
/// Set on-connect callback.
///
/// It get called once per connection and result of the call
/// get stored to the request's extensions.
pub fn on_connect<F, I>(mut self, f: F) -> Self
where
F: Fn(&T) -> I + 'static,
I: Clone + 'static,
{
self.on_connect = Some(Rc::new(move |io| Box::new(Data(f(io)))));
self
}
/// Set req request callback.
///
/// It get called once per request.
pub fn on_request<Filter, F>(mut self, f: F) -> Self
pub fn on_request<R, FR>(mut self, f: FR) -> Self
where
F: IntoService<Filter>,
Filter: Service<
Request = (Request, Rc<RefCell<T>>),
Response = Request,
Error = Response,
> + 'static,
FR: IntoService<R>,
R: Service<Request = (Request, IoRef), Response = Request, Error = Response>
+ 'static,
{
self.on_request = Some(boxed::service(f.into_service()));
self
}
/// Finish service configuration and create *http service* for HTTP/1 protocol.
pub fn h1<F, B>(self, service: F) -> H1Service<T, S, B, X, U>
pub fn h1<B, SF>(self, service: SF) -> H1Service<F, S, B, X, U>
where
B: MessageBody,
F: IntoServiceFactory<S>,
SF: IntoServiceFactory<S>,
S::Error: ResponseError,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
@ -245,15 +211,15 @@ where
H1Service::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_connect(self.on_connect)
.on_request(self.on_request)
}
// pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
/// Finish service configuration and create *http service* for HTTP/2 protocol.
pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
pub fn h2<B, SF>(self, service: SF) -> H1Service<F, S, B, X, U>
where
B: MessageBody + 'static,
F: IntoServiceFactory<S>,
SF: IntoServiceFactory<S>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
@ -266,14 +232,19 @@ where
self.handshake_timeout,
self.pool,
);
H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
// H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
H1Service::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_request(self.on_request)
}
/// Finish service configuration and create `HttpService` instance.
pub fn finish<F, B>(self, service: F) -> HttpService<T, S, B, X, U>
pub fn finish<B, SF>(self, service: SF) -> HttpService<F, S, B, X, U>
where
B: MessageBody + 'static,
F: IntoServiceFactory<S>,
SF: IntoServiceFactory<S>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
@ -290,7 +261,6 @@ where
HttpService::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_connect(self.on_connect)
.on_request(self.on_request)
}
}

View file

@ -42,10 +42,8 @@ impl ClientBuilder {
/// Use custom connector service.
pub fn connector<T>(mut self, connector: T) -> Self
where
T: Service<Request = Connect, Error = ConnectError> + 'static,
T::Response: Connection,
<T::Response as Connection>::Future: 'static,
T::Future: 'static,
T: Service<Request = Connect, Response = Connection, Error = ConnectError>
+ 'static,
{
self.config.connector = Box::new(ConnectorWrapper(connector));
self

View file

@ -1,15 +1,23 @@
use std::{fmt, future::Future, io, net, pin::Pin, task::Context, task::Poll};
use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use crate::http::body::Body;
use crate::http::h1::ClientCodec;
use crate::http::{RequestHeadType, ResponseHead};
use crate::io::IoBoxed;
use crate::Service;
use super::error::{ConnectError, SendRequestError};
use super::response::ClientResponse;
use super::{Connect as ClientConnect, Connection};
pub(crate) type TunnelFuture = Pin<
Box<
dyn Future<
Output = Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError>,
>,
>,
>;
pub(super) struct ConnectorWrapper<T>(pub(crate) T);
pub(super) trait Connect {
@ -25,25 +33,12 @@ pub(super) trait Connect {
&self,
head: RequestHeadType,
addr: Option<net::SocketAddr>,
) -> Pin<
Box<
dyn Future<
Output = Result<
(ResponseHead, Framed<BoxedSocket, ClientCodec>),
SendRequestError,
>,
>,
>,
>;
) -> TunnelFuture;
}
impl<T> Connect for ConnectorWrapper<T>
where
T: Service<Request = ClientConnect, Error = ConnectError>,
T::Response: Connection,
<T::Response as Connection>::Io: 'static,
<T::Response as Connection>::Future: 'static,
<T::Response as Connection>::TunnelFuture: 'static,
T: Service<Request = ClientConnect, Response = Connection, Error = ConnectError>,
T::Future: 'static,
{
fn send_request(
@ -73,16 +68,7 @@ where
&self,
head: RequestHeadType,
addr: Option<net::SocketAddr>,
) -> Pin<
Box<
dyn Future<
Output = Result<
(ResponseHead, Framed<BoxedSocket, ClientCodec>),
SendRequestError,
>,
>,
>,
> {
) -> TunnelFuture {
// connect to the host
let fut = self.0.call(ClientConnect {
uri: head.as_ref().uri.clone(),
@ -93,69 +79,7 @@ where
let connection = fut.await?;
// send request
let (head, framed) = connection.open_tunnel(head).await?;
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
Ok((head, framed))
connection.open_tunnel(head).await
})
}
}
trait AsyncSocket {
fn as_read(&self) -> &(dyn AsyncRead + Unpin);
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin);
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin);
}
struct Socket<T: AsyncRead + AsyncWrite + Unpin>(T);
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncSocket for Socket<T> {
fn as_read(&self) -> &(dyn AsyncRead + Unpin) {
&self.0
}
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) {
&mut self.0
}
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) {
&mut self.0
}
}
pub struct BoxedSocket(Box<dyn AsyncSocket>);
impl fmt::Debug for BoxedSocket {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BoxedSocket")
}
}
impl AsyncRead for BoxedSocket {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf)
}
}
impl AsyncWrite for BoxedSocket {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx)
}
}

View file

@ -8,79 +8,43 @@ use crate::http::h1::ClientCodec;
use crate::http::message::{RequestHeadType, ResponseHead};
use crate::http::payload::Payload;
use crate::http::Protocol;
use crate::io::IoBoxed;
use crate::util::{Bytes, Either, Ready};
use super::error::SendRequestError;
use super::pool::Acquired;
use super::{h1proto, h2proto};
pub(super) enum ConnectionType<Io> {
H1(Io),
pub(super) enum ConnectionType {
H1(IoBoxed),
H2(SendRequest<Bytes>),
}
pub trait Connection {
type Io: AsyncRead + AsyncWrite + Unpin;
type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>;
fn protocol(&self) -> Protocol;
/// Send request and body
fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
self,
head: H,
body: B,
) -> Self::Future;
type TunnelFuture: Future<
Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>;
/// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture;
}
pub(super) trait ConnectionLifetime:
AsyncRead + AsyncWrite + Unpin + 'static
{
/// Close connection
fn close(&mut self);
/// Release connection to the connection pool
fn release(&mut self);
}
#[doc(hidden)]
/// HTTP client connection
pub(super) struct IoConnection<T> {
io: Option<ConnectionType<T>>,
pub struct Connection {
io: Option<ConnectionType>,
created: time::Instant,
pool: Option<Acquired<T>>,
pool: Option<Acquired>,
}
impl<T> fmt::Debug for IoConnection<T>
where
T: fmt::Debug,
{
impl fmt::Debug for Connection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.io {
Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io),
Some(ConnectionType::H1(_)) => write!(f, "H1Connection"),
Some(ConnectionType::H2(_)) => write!(f, "H2Connection"),
None => write!(f, "Connection(Empty)"),
}
}
}
impl<T> IoConnection<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl Connection {
pub(super) fn new(
io: ConnectionType<T>,
io: ConnectionType,
created: time::Instant,
pool: Option<Acquired<T>>,
pool: Option<Acquired>,
) -> Self {
IoConnection {
Self {
pool,
created,
io: Some(io),
@ -97,20 +61,11 @@ where
}
}
pub(super) fn into_inner(self) -> (ConnectionType<T>, time::Instant) {
pub(super) fn into_inner(self) -> (ConnectionType, time::Instant) {
(self.io.unwrap(), self.created)
}
}
impl<T> Connection for IoConnection<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Io = T;
type Future =
Pin<Box<dyn Future<Output = Result<(ResponseHead, Payload), SendRequestError>>>>;
fn protocol(&self) -> Protocol {
pub fn protocol(&self) -> Protocol {
match self.io {
Some(ConnectionType::H1(_)) => Protocol::Http1,
Some(ConnectionType::H2(_)) => Protocol::Http2,
@ -118,58 +73,42 @@ where
}
}
fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
pub(super) async fn send_request<
B: MessageBody + 'static,
H: Into<RequestHeadType>,
>(
mut self,
head: H,
body: B,
) -> Self::Future {
) -> Result<(ResponseHead, Payload), SendRequestError> {
match self.io.take().unwrap() {
ConnectionType::H1(io) => Box::pin(h1proto::send_request(
io,
head.into(),
body,
self.created,
self.pool,
)),
ConnectionType::H2(io) => Box::pin(h2proto::send_request(
io,
head.into(),
body,
self.created,
self.pool,
)),
ConnectionType::H1(io) => {
h1proto::send_request(io, head.into(), body, self.created, self.pool)
.await
}
ConnectionType::H2(io) => {
h2proto::send_request(io, head.into(), body, self.created, self.pool)
.await
}
}
}
type TunnelFuture = Either<
Pin<
Box<
dyn Future<
Output = Result<
(ResponseHead, Framed<Self::Io, ClientCodec>),
SendRequestError,
>,
>,
>,
>,
Ready<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
>;
/// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
pub(super) async fn open_tunnel<H: Into<RequestHeadType>>(
mut self,
head: H,
) -> Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError> {
match self.io.take().unwrap() {
ConnectionType::H1(io) => {
Either::Left(Box::pin(h1proto::open_tunnel(io, head.into())))
}
ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await,
ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() {
pool.release(IoConnection::new(
pool.release(Connection::new(
ConnectionType::H2(io),
self.created,
None,
));
}
Either::Right(Ready::Err(SendRequestError::TunnelNotSupported))
Err(SendRequestError::TunnelNotSupported)
}
}
}

View file

@ -1,8 +1,8 @@
use std::{rc::Rc, task::Context, task::Poll, time::Duration};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::connect::{Connect as TcpConnect, Connector as TcpConnector};
use crate::http::{Protocol, Uri};
use crate::io::{Filter, Io, IoBoxed};
use crate::service::{apply_fn, boxed, Service};
use crate::time::{Millis, Seconds};
use crate::util::timeout::{TimeoutError, TimeoutService};
@ -16,13 +16,12 @@ use super::Connect;
#[cfg(feature = "openssl")]
use crate::connect::openssl::SslConnector as OpensslConnector;
#[cfg(feature = "rustls")]
use crate::connect::rustls::ClientConfig;
#[cfg(feature = "rustls")]
use std::sync::Arc;
//#[cfg(feature = "rustls")]
//use crate::connect::rustls::ClientConfig;
//#[cfg(feature = "rustls")]
//use std::sync::Arc;
type BoxedConnector =
boxed::BoxService<TcpConnect<Uri>, (Box<dyn Io>, Protocol), ConnectError>;
type BoxedConnector = boxed::BoxService<TcpConnect<Uri>, IoBoxed, ConnectError>;
/// Manages http client network connectivity.
///
@ -47,9 +46,6 @@ pub struct Connector {
ssl_connector: Option<BoxedConnector>,
}
trait Io: AsyncRead + AsyncWrite + Unpin {}
impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
impl Default for Connector {
fn default() -> Self {
Connector::new()
@ -61,7 +57,7 @@ impl Connector {
let conn = Connector {
connector: boxed::service(
TcpConnector::new()
.map(|io| (Box::new(io) as Box<dyn Io>, Protocol::Http1))
.map(|io| io.into_boxed())
.map_err(ConnectError::from),
),
ssl_connector: None,
@ -82,28 +78,28 @@ impl Connector {
.map_err(|e| error!("Cannot set ALPN protocol: {:?}", e));
conn.openssl(ssl.build())
}
#[cfg(all(not(feature = "openssl"), feature = "rustls"))]
{
use rust_tls::{OwnedTrustAnchor, RootCertStore};
// #[cfg(all(not(feature = "openssl"), feature = "rustls"))]
// {
// use rust_tls::{OwnedTrustAnchor, RootCertStore};
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let mut cert_store = RootCertStore::empty();
cert_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}),
);
let mut config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(cert_store)
.with_no_client_auth();
config.alpn_protocols = protos;
conn.rustls(Arc::new(config))
}
// let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
// let mut cert_store = RootCertStore::empty();
// cert_store.add_server_trust_anchors(
// webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
// OwnedTrustAnchor::from_subject_spki_name_constraints(
// ta.subject,
// ta.spki,
// ta.name_constraints,
// )
// }),
// );
// let mut config = ClientConfig::builder()
// .with_safe_defaults()
// .with_root_certificates(cert_store)
// .with_no_client_auth();
// config.alpn_protocols = protos;
// conn.rustls(Arc::new(config))
// }
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
{
conn
@ -126,41 +122,29 @@ impl Connector {
pub fn openssl(self, connector: OpensslConnector) -> Self {
use crate::connect::openssl::OpensslConnector;
const H2: &[u8] = b"h2";
self.secure_connector(OpensslConnector::new(connector).map(|sock| {
let h2 = sock
.ssl()
.selected_alpn_protocol()
.map(|protos| protos.windows(2).any(|w| w == H2))
.unwrap_or(false);
if h2 {
(sock, Protocol::Http2)
} else {
(sock, Protocol::Http1)
}
}))
self.secure_connector(OpensslConnector::new(connector))
}
#[cfg(feature = "rustls")]
/// Use rustls connector for secured connections.
pub fn rustls(self, connector: Arc<ClientConfig>) -> Self {
use crate::connect::rustls::RustlsConnector;
// #[cfg(feature = "rustls")]
// /// Use rustls connector for secured connections.
// pub fn rustls(self, connector: Arc<ClientConfig>) -> Self {
// use crate::connect::rustls::RustlsConnector;
const H2: &[u8] = b"h2";
self.secure_connector(RustlsConnector::new(connector).map(|sock| {
let h2 = sock
.get_ref()
.1
.alpn_protocol()
.map(|protos| protos.windows(2).any(|w| w == H2))
.unwrap_or(false);
if h2 {
(Box::new(sock) as Box<dyn Io>, Protocol::Http2)
} else {
(Box::new(sock) as Box<dyn Io>, Protocol::Http1)
}
}))
}
// const H2: &[u8] = b"h2";
// self.secure_connector(RustlsConnector::new(connector).map(|sock| {
// let h2 = sock
// .get_ref()
// .1
// .alpn_protocol()
// .map(|protos| protos.windows(2).any(|w| w == H2))
// .unwrap_or(false);
// if h2 {
// (Box::new(sock) as Box<dyn Io>, Protocol::Http2)
// } else {
// (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
// }
// }))
// }
/// Set total number of simultaneous connections per type of scheme.
///
@ -206,36 +190,36 @@ impl Connector {
}
/// Use custom connector to open un-secured connections.
pub fn connector<T, U>(mut self, connector: T) -> Self
pub fn connector<T, U, F>(mut self, connector: T) -> Self
where
U: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<
Request = TcpConnect<Uri>,
Response = (U, Protocol),
Response = Io<F>,
Error = crate::connect::ConnectError,
> + 'static,
F: Filter,
{
self.connector = boxed::service(
connector
.map(|(io, proto)| (Box::new(io) as Box<dyn Io>, proto))
.map(|io| io.into_boxed())
.map_err(ConnectError::from),
);
self
}
/// Use custom connector to open secure connections.
pub fn secure_connector<T, U>(mut self, connector: T) -> Self
pub fn secure_connector<T, F>(mut self, connector: T) -> Self
where
U: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<
Request = TcpConnect<Uri>,
Response = (U, Protocol),
Response = Io<F>,
Error = crate::connect::ConnectError,
> + 'static,
F: Filter,
{
self.ssl_connector = Some(boxed::service(
connector
.map(|(io, proto)| (Box::new(io) as Box<dyn Io>, proto))
.map(|io| io.into_boxed())
.map_err(ConnectError::from),
));
self
@ -246,12 +230,13 @@ impl Connector {
/// its combinator chain.
pub fn finish(
self,
) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
+ Clone {
let tcp_service = connector(self.connector, self.timeout);
) -> impl Service<Request = Connect, Response = Connection, Error = ConnectError> + Clone
{
let tcp_service =
connector(self.connector, self.timeout, self.disconnect_timeout);
let ssl_pool = if let Some(ssl_connector) = self.ssl_connector {
let srv = connector(ssl_connector, self.timeout);
let srv = connector(ssl_connector, self.timeout, self.disconnect_timeout);
Some(ConnectionPool::new(
srv,
self.conn_lifetime,
@ -279,9 +264,10 @@ impl Connector {
fn connector(
connector: BoxedConnector,
timeout: Millis,
disconnect_timeout: Millis,
) -> impl Service<
Request = Connect,
Response = (Box<dyn Io>, Protocol),
Response = IoBoxed,
Error = ConnectError,
Future = impl Unpin,
> + Unpin {
@ -290,6 +276,10 @@ fn connector(
apply_fn(connector, |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
})
.map(move |io: IoBoxed| {
io.set_disconnect_timeout(disconnect_timeout);
io
})
.map_err(ConnectError::from),
)
.map_err(|e| match e {
@ -298,28 +288,25 @@ fn connector(
})
}
type Pool<T> = ConnectionPool<T, Box<dyn Io>>;
struct InnerConnector<T> {
tcp_pool: Pool<T>,
ssl_pool: Option<Pool<T>>,
tcp_pool: ConnectionPool<T>,
ssl_pool: Option<ConnectionPool<T>>,
}
impl<T> Service for InnerConnector<T>
where
T: Service<
Request = Connect,
Response = (Box<dyn Io>, Protocol),
Error = ConnectError,
> + Unpin
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError>
+ Unpin
+ 'static,
T::Future: Unpin,
{
type Request = Connect;
type Response = <Pool<T> as Service>::Response;
type Response = <ConnectionPool<T> as Service>::Response;
type Error = ConnectError;
type Future =
Either<<Pool<T> as Service>::Future, Ready<Self::Response, Self::Error>>;
type Future = Either<
<ConnectionPool<T> as Service>::Future,
Ready<Self::Response, Self::Error>,
>;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {

View file

@ -1,28 +1,27 @@
use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time};
use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time::Instant};
use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use crate::http::body::{BodySize, MessageBody};
use crate::http::error::PayloadError;
use crate::http::h1;
use crate::http::header::{HeaderMap, HeaderValue, HOST};
use crate::http::message::{RequestHeadType, ResponseHead};
use crate::http::payload::{Payload, PayloadStream};
use crate::io::IoBoxed;
use crate::util::{next, poll_fn, send, BufMut, Bytes, BytesMut};
use crate::{Sink, Stream};
use super::connection::{ConnectionLifetime, ConnectionType, IoConnection};
use super::connection::{Connection, ConnectionType};
use super::error::{ConnectError, SendRequestError};
use super::pool::Acquired;
pub(super) async fn send_request<T, B>(
io: T,
pub(super) async fn send_request<B>(
io: IoBoxed,
mut head: RequestHeadType,
body: B,
created: time::Instant,
pool: Option<Acquired<T>>,
created: Instant,
pool: Option<Acquired>,
) -> Result<(ResponseHead, Payload), SendRequestError>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody,
{
// set request host header
@ -52,209 +51,130 @@ where
}
}
let io = H1Connection {
created,
pool,
io: Some(io),
};
// let io = H1Connection {
// created,
// pool,
// io: Some(io),
// };
// create Framed and send request
let mut framed = Framed::new(io, h1::ClientCodec::default());
send(&mut framed, (head, body.size()).into()).await?;
// send request
let codec = h1::ClientCodec::default();
io.send((head, body.size()).into(), &codec).await?;
// send request body
match body.size() {
BodySize::None | BodySize::Empty | BodySize::Sized(0) => (),
_ => send_body(body, &mut framed).await?,
_ => {
send_body(body, &io, &codec).await?;
}
};
// read response and init read body
let head = if let Some(result) = next(&mut framed).await {
result.map_err(SendRequestError::from)?
let head = if let Some(result) = io.next(&codec).await? {
result
} else {
return Err(SendRequestError::from(ConnectError::Disconnected));
};
match framed.get_codec().message_type() {
match codec.message_type() {
h1::MessageType::None => {
let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close);
let force_close = !codec.keepalive();
release_connection(io, force_close, created, pool);
Ok((head, Payload::None))
}
_ => {
let pl: PayloadStream = Box::pin(PlStream::new(framed));
let pl: PayloadStream = Box::pin(PlStream::new(io, codec, created, pool));
Ok((head, pl.into()))
}
}
}
pub(super) async fn open_tunnel<T>(
io: T,
pub(super) async fn open_tunnel(
io: IoBoxed,
head: RequestHeadType,
) -> Result<(ResponseHead, Framed<T, h1::ClientCodec>), SendRequestError>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
) -> Result<(ResponseHead, IoBoxed, h1::ClientCodec), SendRequestError> {
// create Framed and send request
let mut framed = Framed::new(io, h1::ClientCodec::default());
send(&mut framed, (head, BodySize::None).into()).await?;
let codec = h1::ClientCodec::default();
io.send((head, BodySize::None).into(), &codec).await?;
// read response
if let Some(result) = next(&mut framed).await {
let head = result.map_err(SendRequestError::from)?;
Ok((head, framed))
if let Some(head) = io.next(&codec).await? {
Ok((head, io, codec))
} else {
Err(SendRequestError::from(ConnectError::Disconnected))
}
}
/// send request body to the peer
pub(super) async fn send_body<I, B>(
pub(super) async fn send_body<B>(
mut body: B,
framed: &mut Framed<I, h1::ClientCodec>,
io: &IoBoxed,
codec: &h1::ClientCodec,
) -> Result<(), SendRequestError>
where
I: ConnectionLifetime,
B: MessageBody,
{
let mut eof = false;
while !eof {
while !eof && !framed.is_write_buf_full() {
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
Some(result) => {
framed.write(h1::Message::Chunk(Some(result?)))?;
}
None => {
eof = true;
framed.write(h1::Message::Chunk(None))?;
let wrt = io.write();
loop {
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
Some(result) => {
if !wrt.encode(h1::Message::Chunk(Some(result?)), codec)? {
wrt.flush(false).await?;
}
}
}
if !framed.is_write_buf_empty() {
poll_fn(|cx| match framed.flush(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
Poll::Pending => {
if !framed.is_write_buf_full() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
})
.await?;
None => {
wrt.encode(h1::Message::Chunk(None), codec)?;
break;
}
}
}
poll_fn(|cx| Pin::new(&mut *framed).poll_flush(cx)).await?;
wrt.flush(true).await?;
Ok(())
}
#[doc(hidden)]
/// HTTP client connection
pub(super) struct H1Connection<T> {
io: Option<T>,
created: time::Instant,
pool: Option<Acquired<T>>,
pub(super) struct PlStream {
io: Option<IoBoxed>,
codec: h1::ClientPayloadCodec,
created: Instant,
pool: Option<Acquired>,
}
impl<T> ConnectionLifetime for H1Connection<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Close connection
fn close(&mut self) {
if let Some(mut pool) = self.pool.take() {
if let Some(io) = self.io.take() {
pool.close(IoConnection::new(
ConnectionType::H1(io),
self.created,
None,
));
}
}
}
/// Release this connection to the connection pool
fn release(&mut self) {
if let Some(mut pool) = self.pool.take() {
if let Some(io) = self.io.take() {
pool.release(IoConnection::new(
ConnectionType::H1(io),
self.created,
None,
));
}
}
}
}
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.io.as_mut().unwrap()).poll_flush(cx)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)
}
}
pub(super) struct PlStream<Io> {
framed: Option<Framed<Io, h1::ClientPayloadCodec>>,
}
impl<Io: ConnectionLifetime> PlStream<Io> {
fn new(framed: Framed<Io, h1::ClientCodec>) -> Self {
impl PlStream {
fn new(
io: IoBoxed,
codec: h1::ClientCodec,
created: Instant,
pool: Option<Acquired>,
) -> Self {
PlStream {
framed: Some(framed.map_codec(|codec| codec.into_payload_codec())),
io: Some(io),
codec: codec.into_payload_codec(),
created,
pool,
}
}
}
impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
impl Stream for PlStream {
type Item = Result<Bytes, PayloadError>;
fn poll_next(
self: Pin<&mut Self>,
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let mut this = self.as_mut();
match this.framed.as_mut().unwrap().next_item(cx)? {
match this.io.as_ref().unwrap().poll_next(&this.codec, cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk {
Poll::Ready(Some(Ok(chunk)))
} else {
let framed = this.framed.take().unwrap();
let force_close = !framed.get_codec().keepalive();
release_connection(framed, force_close);
let io = this.io.take().unwrap();
let force_close = !this.codec.keepalive();
release_connection(io, force_close, this.created, this.pool.take());
Poll::Ready(None)
}
}
@ -263,14 +183,17 @@ impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
}
}
fn release_connection<T, U>(framed: Framed<T, U>, force_close: bool)
where
T: ConnectionLifetime,
{
let mut parts = framed.into_parts();
if !force_close && parts.read_buf.is_empty() && parts.write_buf.is_empty() {
parts.io.release()
} else {
parts.io.close()
fn release_connection(
io: IoBoxed,
force_close: bool,
created: Instant,
mut pool: Option<Acquired>,
) {
if force_close || io.is_closed() || io.read().with_buf(|buf| !buf.is_empty()) {
if let Some(mut pool) = pool.take() {
pool.close(Connection::new(ConnectionType::H1(io), created, None));
}
} else if let Some(mut pool) = pool.take() {
pool.release(Connection::new(ConnectionType::H1(io), created, None));
}
}

View file

@ -11,19 +11,18 @@ use crate::http::message::{RequestHeadType, ResponseHead};
use crate::http::payload::Payload;
use crate::util::{poll_fn, Bytes};
use super::connection::{ConnectionType, IoConnection};
use super::connection::{Connection, ConnectionType};
use super::error::SendRequestError;
use super::pool::Acquired;
pub(super) async fn send_request<T, B>(
pub(super) async fn send_request<B>(
mut io: SendRequest<Bytes>,
head: RequestHeadType,
body: B,
created: time::Instant,
pool: Option<Acquired<T>>,
pool: Option<Acquired>,
) -> Result<(ResponseHead, Payload), SendRequestError>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
B: MessageBody,
{
trace!("Sending client request: {:?} {:?}", head, body.size());
@ -161,17 +160,17 @@ async fn send_body<B: MessageBody>(
}
// release SendRequest object
fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>(
fn release(
io: SendRequest<Bytes>,
pool: Option<Acquired<T>>,
pool: Option<Acquired>,
created: time::Instant,
close: bool,
) {
if let Some(mut pool) = pool {
if close {
pool.close(IoConnection::new(ConnectionType::H2(io), created, None));
pool.close(Connection::new(ConnectionType::H2(io), created, None));
} else {
pool.release(IoConnection::new(ConnectionType::H2(io), created, None));
pool.release(Connection::new(ConnectionType::H2(io), created, None));
}
}
}

View file

@ -34,7 +34,6 @@ mod test;
pub mod ws;
pub use self::builder::ClientBuilder;
pub use self::connect::BoxedSocket;
pub use self::connection::Connection;
pub use self::connector::Connector;
pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder};
@ -47,7 +46,7 @@ use crate::http::error::HttpError;
use crate::http::{HeaderMap, Method, RequestHead, Uri};
use crate::time::Millis;
use self::connect::{Connect as InnerConnect, ConnectorWrapper};
use self::connect::{Connect as HttpConnect, ConnectorWrapper};
#[derive(Clone)]
pub struct Connect {
@ -76,7 +75,7 @@ pub struct Connect {
pub struct Client(Rc<ClientConfig>);
pub(self) struct ClientConfig {
pub(self) connector: Box<dyn InnerConnect>,
pub(self) connector: Box<dyn HttpConnect>,
pub(self) headers: HeaderMap,
pub(self) timeout: Millis,
}

View file

@ -2,19 +2,20 @@ use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc};
use h2::client::{Builder, Connection, SendRequest};
use h2::client::{Builder, Connection as H2Connection, SendRequest};
use http::uri::Authority;
use crate::channel::pool;
use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
use crate::http::Protocol;
use crate::io::IoBoxed;
use crate::rt::spawn;
use crate::service::Service;
use crate::task::LocalWaker;
use crate::time::{now, sleep, Millis, Sleep};
use crate::util::{poll_fn, Bytes, HashMap};
use super::connection::{ConnectionType, IoConnection};
use super::connection::{Connection, ConnectionType};
use super::error::ConnectError;
use super::Connect;
@ -29,16 +30,15 @@ impl From<Authority> for Key {
}
}
type Waiter<Io> = pool::Sender<Result<IoConnection<Io>, ConnectError>>;
type WaiterReceiver<Io> = pool::Receiver<Result<IoConnection<Io>, ConnectError>>;
type Waiter = pool::Sender<Result<Connection, ConnectError>>;
type WaiterReceiver = pool::Receiver<Result<Connection, ConnectError>>;
/// Connections pool
pub(super) struct ConnectionPool<T, Io: 'static>(Rc<T>, Rc<RefCell<Inner<Io>>>);
pub(super) struct ConnectionPool<T>(Rc<T>, Rc<RefCell<Inner>>);
impl<T, Io> ConnectionPool<T, Io>
impl<T> ConnectionPool<T>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError>
+ Unpin
+ 'static,
T::Future: Unpin,
@ -73,35 +73,27 @@ where
}
}
impl<T, Io> Drop for ConnectionPool<T, Io>
where
Io: 'static,
{
impl<T> Drop for ConnectionPool<T> {
fn drop(&mut self) {
self.1.borrow().waker.wake();
}
}
impl<T, Io> Clone for ConnectionPool<T, Io>
where
Io: 'static,
{
impl<T> Clone for ConnectionPool<T> {
fn clone(&self) -> Self {
ConnectionPool(self.0.clone(), self.1.clone())
}
}
impl<T, Io> Service for ConnectionPool<T, Io>
impl<T> Service for ConnectionPool<T>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ 'static,
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError> + 'static,
T::Future: Unpin,
{
type Request = Connect;
type Response = IoConnection<Io>;
type Response = Connection;
type Error = ConnectError;
type Future = Pin<Box<dyn Future<Output = Result<IoConnection<Io>, ConnectError>>>>;
type Future = Pin<Box<dyn Future<Output = Result<Connection, ConnectError>>>>;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -127,11 +119,12 @@ where
};
// acquire connection
match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await {
let result = inner.borrow_mut().acquire(&key);
match result {
// use existing connection
Acquire::Acquired(io, created) => {
trace!("Use existing connection for {:?}", req.uri);
Ok(IoConnection::new(
Ok(Connection::new(
io,
created,
Some(Acquired(key, Some(inner))),
@ -165,31 +158,31 @@ where
}
}
enum Acquire<T> {
Acquired(ConnectionType<T>, Instant),
enum Acquire {
Acquired(ConnectionType, Instant),
Available,
NotAvailable,
}
struct AvailableConnection<Io> {
io: ConnectionType<Io>,
struct AvailableConnection {
io: ConnectionType,
used: Instant,
created: Instant,
}
pub(super) struct Inner<Io> {
pub(super) struct Inner {
conn_lifetime: Duration,
conn_keep_alive: Duration,
disconnect_timeout: Millis,
limit: usize,
acquired: usize,
available: HashMap<Key, VecDeque<AvailableConnection<Io>>>,
waiters: VecDeque<(Key, Connect, Waiter<Io>)>,
available: HashMap<Key, VecDeque<AvailableConnection>>,
waiters: VecDeque<(Key, Connect, Waiter)>,
waker: LocalWaker,
pool: pool::Pool<Result<IoConnection<Io>, ConnectError>>,
pool: pool::Pool<Result<Connection, ConnectError>>,
}
impl<Io> Inner<Io> {
impl Inner {
fn reserve(&mut self) {
self.acquired += 1;
}
@ -199,12 +192,9 @@ impl<Io> Inner<Io> {
}
}
impl<Io> Inner<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl Inner {
/// connection is not available, wait
fn wait_for(&mut self, connect: Connect) -> WaiterReceiver<Io> {
fn wait_for(&mut self, connect: Connect) -> WaiterReceiver {
let (tx, rx) = self.pool.channel();
let key: Key = connect.uri.authority().unwrap().clone().into();
self.waiters.push_back((key, connect, tx));
@ -226,7 +216,7 @@ where
}
}
fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire<Io> {
fn acquire(&mut self, key: &Key) -> Acquire {
self.cleanup();
// check limits
@ -249,19 +239,22 @@ where
CloseConnection::spawn(io, self.disconnect_timeout);
}
} else {
let mut io = conn.io;
let mut buf = [0; 2];
let mut read_buf = ReadBuf::new(&mut buf);
if let ConnectionType::H1(ref mut s) = io {
match Pin::new(s).poll_read(cx, &mut read_buf) {
Poll::Pending => (),
Poll::Ready(Ok(_)) if !read_buf.filled().is_empty() => {
if let ConnectionType::H1(io) = io {
CloseConnection::spawn(io, self.disconnect_timeout);
}
continue;
let io = conn.io;
if let ConnectionType::H1(ref s) = io {
if s.is_closed() {
continue;
}
let is_valid = s.read().with_buf(|buf| {
if buf.is_empty() || (buf.len() == 2 && &buf[..] == b"\r\n")
{
buf.clear();
true
} else {
false
}
_ => continue,
});
if !is_valid {
continue;
}
}
return Acquire::Acquired(io, conn.created);
@ -271,7 +264,7 @@ where
Acquire::Available
}
fn release_conn(&mut self, key: &Key, io: ConnectionType<Io>, created: Instant) {
fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) {
self.acquired -= 1;
self.available
.entry(key.clone())
@ -284,7 +277,7 @@ where
self.check_availibility();
}
fn release_close(&mut self, io: ConnectionType<Io>) {
fn release_close(&mut self, io: ConnectionType) {
self.acquired -= 1;
if let ConnectionType::H1(io) = io {
CloseConnection::spawn(io, self.disconnect_timeout);
@ -300,19 +293,14 @@ where
}
}
struct ConnectionPoolSupport<T, Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
struct ConnectionPoolSupport<T> {
connector: T,
inner: Rc<RefCell<Inner<Io>>>,
inner: Rc<RefCell<Inner>>,
}
impl<T, Io> Future for ConnectionPoolSupport<T, Io>
impl<T> Future for ConnectionPoolSupport<T>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
+ Unpin,
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError> + Unpin,
T::Future: Unpin + 'static,
{
type Output = ();
@ -337,11 +325,11 @@ where
};
let key = key.clone();
match inner.acquire(&key, cx) {
match inner.acquire(&key) {
Acquire::NotAvailable => break,
Acquire::Acquired(io, created) => {
let (key, _, tx) = inner.waiters.pop_front().unwrap();
let _ = tx.send(Ok(IoConnection::new(
let _ = tx.send(Ok(Connection::new(
io,
created,
Some(Acquired(key.clone(), Some(this.inner.clone()))),
@ -363,94 +351,43 @@ where
}
}
struct CloseConnection<T> {
io: T,
struct CloseConnection {
io: IoBoxed,
timeout: Option<Sleep>,
shutdown: bool,
}
impl<T> CloseConnection<T>
where
T: AsyncWrite + AsyncRead + Unpin + 'static,
{
fn spawn(io: T, timeout: Millis) {
spawn(Self {
io,
shutdown: false,
timeout: timeout.map(sleep),
impl CloseConnection {
fn spawn(io: IoBoxed, timeout: Millis) {
spawn(async move {
io.shutdown().await;
});
}
}
impl<T> Future for CloseConnection<T>
where
T: AsyncWrite + AsyncRead + Unpin,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let mut this = self.as_mut();
// shutdown WRITE side
match Pin::new(&mut this.io).poll_shutdown(cx) {
Poll::Ready(Ok(())) => this.shutdown = true,
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(_)) => return Poll::Ready(()),
}
// read until 0 or err
if let Some(ref timeout) = this.timeout {
match timeout.poll_elapsed(cx) {
Poll::Ready(_) => (),
Poll::Pending => {
let mut buf = [0u8; 512];
let mut read_buf = ReadBuf::new(&mut buf);
loop {
match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(_)) => return Poll::Ready(()),
Poll::Ready(Ok(_)) => {
if read_buf.filled().is_empty() {
return Poll::Ready(());
}
continue;
}
}
}
}
}
}
Poll::Ready(())
}
}
struct OpenConnection<F, Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
struct OpenConnection<F> {
fut: F,
h2: Option<
Pin<
Box<
dyn Future<
Output = Result<
(SendRequest<Bytes>, Connection<Io, Bytes>),
(SendRequest<Bytes>, H2Connection<Bytes>),
h2::Error,
>,
>,
>,
>,
>,
tx: Option<Waiter<Io>>,
guard: Option<OpenGuard<Io>>,
tx: Option<Waiter>,
guard: Option<OpenGuard>,
}
impl<F, Io> OpenConnection<F, Io>
impl<F> OpenConnection<F>
where
F: Future<Output = Result<(Io, Protocol), ConnectError>> + Unpin + 'static,
Io: AsyncRead + AsyncWrite + Unpin + 'static,
F: Future<Output = Result<IoBoxed, ConnectError>> + Unpin + 'static,
{
fn spawn(key: Key, tx: Waiter<Io>, inner: Rc<RefCell<Inner<Io>>>, fut: F) {
fn spawn(key: Key, tx: Waiter, inner: Rc<RefCell<Inner>>, fut: F) {
spawn(OpenConnection {
fut,
h2: None,
@ -463,10 +400,9 @@ where
}
}
impl<F, Io> Future for OpenConnection<F, Io>
impl<F> Future for OpenConnection<F>
where
F: Future<Output = Result<(Io, Protocol), ConnectError>> + Unpin,
Io: AsyncRead + AsyncWrite + Unpin,
F: Future<Output = Result<IoBoxed, ConnectError>> + Unpin,
{
type Output = ();
@ -478,7 +414,7 @@ where
return match Pin::new(h2).poll(cx) {
Poll::Ready(Ok((snd, connection))) => {
// h2 connection is ready
let conn = IoConnection::new(
let conn = Connection::new(
ConnectionType::H2(snd),
now(),
Some(this.guard.take().unwrap().consume()),
@ -488,7 +424,7 @@ where
conn.release()
}
spawn(async move {
let _ = connection.await;
// let _ = connection.await;
});
Poll::Ready(())
}
@ -511,52 +447,43 @@ where
}
Poll::Ready(())
}
Poll::Ready(Ok((io, proto))) => {
Poll::Ready(Ok(io)) => {
trace!("Connection is established");
// handle http1 proto
if proto == Protocol::Http1 {
let conn = IoConnection::new(
ConnectionType::H1(io),
now(),
Some(this.guard.take().unwrap().consume()),
);
if let Err(Ok(conn)) = this.tx.take().unwrap().send(Ok(conn)) {
// waiter is gone, return connection to pool
conn.release()
}
Poll::Ready(())
} else {
// init http2 handshake
this.h2 = Some(Box::pin(Builder::new().handshake(io)));
self.poll(cx)
//if proto == Protocol::Http1 {
let conn = Connection::new(
ConnectionType::H1(io),
now(),
Some(this.guard.take().unwrap().consume()),
);
if let Err(Ok(conn)) = this.tx.take().unwrap().send(Ok(conn)) {
// waiter is gone, return connection to pool
conn.release()
}
Poll::Ready(())
// } else {
// init http2 handshake
// this.h2 = Some(Box::pin(Builder::new().handshake(io)));
// self.poll(cx)
//}
}
Poll::Pending => Poll::Pending,
}
}
}
struct OpenGuard<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
struct OpenGuard {
key: Key,
inner: Option<Rc<RefCell<Inner<Io>>>>,
inner: Option<Rc<RefCell<Inner>>>,
}
impl<Io> OpenGuard<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn consume(mut self) -> Acquired<Io> {
impl OpenGuard {
fn consume(mut self) -> Acquired {
Acquired(self.key.clone(), self.inner.take())
}
}
impl<Io> Drop for OpenGuard<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl Drop for OpenGuard {
fn drop(&mut self) {
if let Some(i) = self.inner.take() {
let mut inner = i.as_ref().borrow_mut();
@ -566,20 +493,17 @@ where
}
}
pub(super) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
pub(super) struct Acquired(Key, Option<Rc<RefCell<Inner>>>);
impl<T> Acquired<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
pub(super) fn close(&mut self, conn: IoConnection<T>) {
impl Acquired {
pub(super) fn close(&mut self, conn: Connection) {
if let Some(inner) = self.1.take() {
let (io, _) = conn.into_inner();
inner.as_ref().borrow_mut().release_close(io);
}
}
pub(super) fn release(&mut self, conn: IoConnection<T>) {
pub(super) fn release(&mut self, conn: Connection) {
if let Some(inner) = self.1.take() {
let (io, created) = conn.into_inner();
inner
@ -590,7 +514,7 @@ where
}
}
impl<T> Drop for Acquired<T> {
impl Drop for Acquired {
fn drop(&mut self) {
if let Some(inner) = self.1.take() {
inner.borrow_mut().release();

View file

@ -6,17 +6,16 @@ use coo_kie::{Cookie, CookieJar};
use nanorand::{Rng, WyRand};
use crate::codec::{AsyncRead, AsyncWrite, Framed};
use crate::framed::{DispatchItem, Dispatcher, State};
use crate::http::error::HttpError;
use crate::http::header::{self, HeaderName, HeaderValue, AUTHORIZATION};
use crate::http::{ConnectionType, Payload, RequestHead, StatusCode, Uri};
use crate::io::{DefaultFilter, DispatchItem, Dispatcher, Filter, Io, IoBoxed};
use crate::service::{apply_fn, into_service, IntoService, Service};
use crate::util::Either;
use crate::{channel::mpsc, rt, time::timeout, util::sink, util::Ready, ws};
pub use crate::ws::{CloseCode, CloseReason, Frame, Message};
use super::connect::BoxedSocket;
use super::error::{InvalidUrl, SendRequestError, WsClientError};
use super::response::ClientResponse;
use super::ClientConfig;
@ -311,7 +310,7 @@ impl WsRequest {
let fut = self.config.connector.open_tunnel(head.into(), self.addr);
// set request timeout
let (head, framed) = if self.config.timeout.non_zero() {
let (head, io, codec) = if self.config.timeout.non_zero() {
timeout(self.config.timeout, fut)
.await
.map_err(|_| SendRequestError::Timeout)
@ -377,13 +376,12 @@ impl WsRequest {
// response and ws io
Ok(WsConnection::new(
ClientResponse::new(head, Payload::None),
framed.map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)
} else {
ws::Codec::new().max_size(max_size).client_mode()
}
}),
io,
if server_mode {
ws::Codec::new().max_size(max_size)
} else {
ws::Codec::new().max_size(max_size).client_mode()
},
))
}
}
@ -403,31 +401,20 @@ impl fmt::Debug for WsRequest {
}
}
pub struct WsConnection<Io = BoxedSocket> {
io: Io,
state: State,
pub struct WsConnection {
io: IoBoxed,
codec: ws::Codec,
res: ClientResponse,
}
impl<Io> WsConnection<Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn new(res: ClientResponse, framed: Framed<Io, ws::Codec>) -> Self {
let (io, codec, state) = State::from_framed(framed);
Self {
io,
state,
codec,
res,
}
impl WsConnection {
fn new(res: ClientResponse, io: IoBoxed, codec: ws::Codec) -> Self {
Self { io, codec, res }
}
/// Get ws sink
pub fn sink(&self) -> ws::WsSink {
ws::WsSink::new(self.state.clone(), self.codec.clone())
ws::WsSink::new(self.io.get_ref(), self.codec.clone())
}
/// Get reference to response
@ -458,10 +445,10 @@ where
}
/// Start client websockets service.
pub async fn start<T, F>(self, service: F) -> Result<(), ws::WsError<T::Error>>
pub async fn start<T, U>(self, service: U) -> Result<(), ws::WsError<T::Error>>
where
T: Service<Request = ws::Frame, Response = Option<ws::Message>> + 'static,
F: IntoService<T>,
U: IntoService<T>,
{
let service = apply_fn(
service.into_service().map_err(ws::WsError::Service),
@ -475,21 +462,22 @@ where
DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => {
Either::Right(Ready::Err(ws::WsError::Protocol(e)))
}
DispatchItem::IoError(e) => {
DispatchItem::Disconnect(Some(e)) => {
Either::Right(Ready::Err(ws::WsError::Io(e)))
}
DispatchItem::Disconnect(None) => {
Either::Right(Ready::Err(ws::WsError::Disconnected))
}
},
);
Dispatcher::new(self.io, self.codec, self.state, service, Default::default())
.await
Dispatcher::new(self.io, self.codec, service, Default::default()).await
}
/// Consumes the `WsConnection`, returning it'as underlying I/O framed object
/// and response.
pub fn into_inner(self) -> (ClientResponse, Framed<Io, ws::Codec>) {
let framed = self.state.into_framed(self.io, self.codec);
(self.res, framed)
pub fn into_inner(self) -> (ClientResponse, IoBoxed, ws::Codec) {
(self.res, self.io, self.codec)
}
}

View file

@ -1,7 +1,7 @@
use std::{cell::Cell, cell::RefCell, ptr::copy_nonoverlapping, rc::Rc, time};
use std::{cell::Cell, ptr::copy_nonoverlapping, rc::Rc, time};
use crate::framed::Timer;
use crate::http::{Request, Response};
use crate::io::{IoRef, Timer};
use crate::service::boxed::BoxService;
use crate::time::{sleep, Millis, Seconds, Sleep};
use crate::util::{BytesMut, PoolId};
@ -94,9 +94,9 @@ impl ServiceConfig {
}
}
pub(super) type OnRequest<T> = BoxService<(Request, Rc<RefCell<T>>), Request, Response>;
pub(super) type OnRequest = BoxService<(Request, IoRef), Request, Response>;
pub(super) struct DispatcherConfig<T, S, X, U> {
pub(super) struct DispatcherConfig<S, X, U> {
pub(super) service: S,
pub(super) expect: X,
pub(super) upgrade: Option<U>,
@ -107,16 +107,16 @@ pub(super) struct DispatcherConfig<T, S, X, U> {
pub(super) timer: DateService,
pub(super) timer_h1: Timer,
pub(super) pool: PoolId,
pub(super) on_request: Option<OnRequest<T>>,
pub(super) on_request: Option<OnRequest>,
}
impl<T, S, X, U> DispatcherConfig<T, S, X, U> {
impl<S, X, U> DispatcherConfig<S, X, U> {
pub(super) fn new(
cfg: ServiceConfig,
service: S,
expect: X,
upgrade: Option<U>,
on_request: Option<OnRequest<T>>,
on_request: Option<OnRequest>,
) -> Self {
DispatcherConfig {
service,
@ -148,10 +148,6 @@ impl<T, S, X, U> DispatcherConfig<T, S, X, U> {
self.keep_alive
.map(|t| self.timer.now() + time::Duration::from(t))
}
pub(super) fn now(&self) -> time::Instant {
self.timer.now()
}
}
const DATE_VALUE_LENGTH_HDR: usize = 39;

View file

@ -322,7 +322,7 @@ impl MessageType for ResponseHead {
Err(ParseError::TooLarge)
} else {
Ok(None)
}
};
}
}
};

View file

@ -1,14 +1,10 @@
//! Framed transport dispatcher
use std::task::{Context, Poll};
use std::{
cell::RefCell, error::Error, fmt, future::Future, marker, net, pin::Pin, rc::Rc,
time,
};
use std::{error::Error, fmt, future::Future, marker, net, pin::Pin, rc::Rc, time};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::{ReadTask, State as IoState, WriteTask};
use crate::io::{Filter, Io, IoRef};
use crate::service::Service;
use crate::util::Bytes;
use crate::{time::now, util::Bytes, util::Either};
use crate::http;
use crate::http::body::{BodySize, MessageBody, ResponseBody};
@ -37,19 +33,24 @@ bitflags::bitflags! {
pin_project_lite::pin_project! {
/// Dispatcher for HTTP/1.1 protocol
pub struct Dispatcher<T, S: Service, B, X: Service, U: Service> {
pub struct Dispatcher<F, S: Service, B, X: Service, U: Service> {
#[pin]
call: CallState<S, X, U>,
st: State<B>,
inner: DispatcherInner<T, S, B, X, U>,
inner: DispatcherInner<F, S, B, X, U>,
}
}
#[derive(derive_more::Display)]
enum State<B> {
Call,
ReadRequest,
ReadPayload,
SendPayload { body: ResponseBody<B> },
#[display(fmt = "State::SendPayload")]
SendPayload {
body: ResponseBody<B>,
},
#[display(fmt = "State::Upgrade")]
Upgrade(Option<Request>),
Stop,
}
@ -65,17 +66,15 @@ pin_project_lite::pin_project! {
}
}
struct DispatcherInner<T, S, B, X, U> {
io: Option<Rc<RefCell<T>>>,
struct DispatcherInner<F, S, B, X, U> {
io: Option<Io<F>>,
flags: Flags,
codec: Codec,
config: Rc<DispatcherConfig<T, S, X, U>>,
state: IoState,
state: IoRef,
config: Rc<DispatcherConfig<S, X, U>>,
expire: time::Instant,
error: Option<DispatchError>,
payload: Option<(PayloadDecoder, PayloadSender)>,
peer_addr: Option<net::SocketAddr>,
on_connect_data: Option<Box<dyn DataFactory>>,
_t: marker::PhantomData<(S, B)>,
}
@ -93,42 +92,33 @@ enum WritePayloadStatus<B> {
Continue,
}
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
impl<F, S, B, X, U> Dispatcher<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter + 'static,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
U: Service<Request = (Request, Io<F>, Codec), Response = ()>,
U::Error: Error + fmt::Display,
{
/// Construct new `Dispatcher` instance with outgoing messages stream.
pub(in crate::http) fn new(
io: T,
config: Rc<DispatcherConfig<T, S, X, U>>,
peer_addr: Option<net::SocketAddr>,
on_connect_data: Option<Box<dyn DataFactory>>,
io: Io<F>,
config: Rc<DispatcherConfig<S, X, U>>,
) -> Self {
let mut expire = now();
let state = io.get_ref();
let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled());
let state = IoState::with_memory_pool(config.pool.into());
state.set_disconnect_timeout(config.client_disconnect);
let mut expire = config.timer_h1.now();
let io = Rc::new(RefCell::new(io));
// slow-request timer
if config.client_timeout.non_zero() {
expire += std::time::Duration::from(config.client_timeout);
expire += time::Duration::from(config.client_timeout);
config.timer_h1.register(expire, expire, &state);
}
// start support io tasks
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
crate::rt::spawn(WriteTask::new(io.clone(), state.clone()));
Dispatcher {
call: CallState::None,
st: State::ReadRequest,
@ -138,27 +128,25 @@ where
error: None,
payload: None,
codec,
config,
state,
config,
expire,
peer_addr,
on_connect_data,
_t: marker::PhantomData,
},
}
}
}
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
impl<F, S, B, X, U> Future for Dispatcher<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
U: Service<Request = (Request, Io<F>, Codec), Response = ()>,
U::Error: Error + fmt::Display + 'static,
{
type Output = Result<(), DispatchError>;
@ -204,7 +192,6 @@ where
)
});
if this.inner.flags.contains(Flags::UPGRADE) {
this.inner.state.stop_io(cx.waker());
*this.st = State::Upgrade(Some(req));
return Poll::Pending;
} else {
@ -292,7 +279,6 @@ where
log::trace!("keep-alive timeout, close connection");
}
*this.st = State::Stop;
continue;
}
@ -307,7 +293,7 @@ where
req,
pl
);
req.head_mut().peer_addr = this.inner.peer_addr;
req.head_mut().io = Some(this.inner.state.clone());
// configure request payload
let upgrade = match pl {
@ -340,16 +326,9 @@ where
);
}
// set on_connect data
if let Some(ref on_connect) = this.inner.on_connect_data
{
on_connect.set(&mut req.extensions_mut());
}
if upgrade {
// Handle UPGRADE request
log::trace!("prep io for upgrade handler");
this.inner.state.stop_io(cx.waker());
*this.st = State::Upgrade(Some(req));
return Poll::Pending;
} else {
@ -361,11 +340,7 @@ where
CallState::Filter {
fut: f.call((
req,
this.inner
.io
.as_ref()
.unwrap()
.clone(),
this.inner.state.clone(),
)),
}
} else if req.head().expect() {
@ -390,13 +365,13 @@ where
if this.inner.flags.contains(Flags::STARTED)
&& (!this.inner.flags.contains(Flags::KEEPALIVE)
|| !this.inner.codec.keepalive_enabled()
|| this.inner.state.is_io_err())
|| !this.inner.state.is_io_open())
{
*this.st = State::Stop;
this.inner.state.dispatcher_stopped();
this.inner.state.stop_dispatcher();
continue;
}
this.inner.state.read().wake(cx.waker());
let _ = read.poll_ready(cx);
return Poll::Pending;
}
Err(err) => {
@ -418,13 +393,13 @@ where
*this.st = State::Stop;
continue;
}
this.inner.state.register_dispatcher(cx.waker());
let _ = read.poll_ready(cx);
return Poll::Pending;
}
}
// consume request's payload
State::ReadPayload => {
if this.inner.state.is_io_err() {
if !this.inner.state.is_io_open() {
*this.st = State::Stop;
} else {
loop {
@ -445,7 +420,7 @@ where
}
// send response body
State::SendPayload { ref mut body } => {
if this.inner.state.is_io_err() {
if !this.inner.state.is_io_open() {
*this.st = State::Stop;
} else {
this.inner.poll_read_payload(cx);
@ -459,7 +434,7 @@ where
this.inner
.state
.write()
.enable_backpressure(Some(cx.waker()));
.enable_backpressure(Some(cx));
return Poll::Pending;
}
WritePayloadStatus::Continue => (),
@ -470,50 +445,48 @@ where
}
// stop io tasks and call upgrade service
State::Upgrade(ref mut req) => {
// check if all io tasks have been stopped
let io = if Rc::strong_count(this.inner.io.as_ref().unwrap()) == 1 {
if let Ok(io) = Rc::try_unwrap(this.inner.io.take().unwrap()) {
io.into_inner()
} else {
return Poll::Ready(Err(DispatchError::InternalError));
}
} else {
// wait next task stop
this.inner.state.register_dispatcher(cx.waker());
return Poll::Pending;
};
log::trace!("initate upgrade handling");
let io = this.inner.io.take().unwrap();
let req = req.take().unwrap();
*this.st = State::Call;
this.inner.state.reset_io_stop();
// Handle UPGRADE request
this.call.set(CallState::Upgrade {
fut: this.inner.config.upgrade.as_ref().unwrap().call((
req,
io,
this.inner.state.clone(),
this.inner.codec.clone(),
)),
});
}
// prepare to shutdown
State::Stop => {
this.inner.state.shutdown_io();
this.inner.unregister_keepalive();
if this
.inner
.io
.as_ref()
.unwrap()
.poll_shutdown(cx)?
.is_ready()
{
// get io error
if this.inner.error.is_none() {
this.inner.error =
this.inner.state.take_error().map(DispatchError::Io);
}
// get io error
if this.inner.error.is_none() {
this.inner.error =
this.inner.state.take_io_error().map(DispatchError::Io);
}
return Poll::Ready(if let Some(err) = this.inner.error.take() {
Err(err)
return Poll::Ready(
if let Some(err) = this.inner.error.take() {
Err(err)
} else {
Ok(())
},
);
} else {
Ok(())
});
return Poll::Pending;
}
}
}
}
@ -536,8 +509,7 @@ where
fn reset_keepalive(&mut self) {
// re-register keep-alive
if self.flags.contains(Flags::KEEPALIVE) && self.config.keep_alive.non_zero() {
let expire = self.config.timer_h1.now()
+ std::time::Duration::from(self.config.keep_alive);
let expire = now() + time::Duration::from(self.config.keep_alive);
if expire != self.expire {
self.config
.timer_h1
@ -571,11 +543,11 @@ where
}
fn send_response(&mut self, msg: Response<()>, body: ResponseBody<B>) -> State<B> {
trace!("Sending response: {:?} body: {:?}", msg, body.size());
trace!("sending response: {:?} body: {:?}", msg, body.size());
// we dont need to process responses if socket is disconnected
// but we still want to handle requests with app service
// so we skip response processing for droppped connection
if !self.state.is_io_err() {
if self.state.is_io_open() {
let result = self
.state
.write()
@ -617,7 +589,7 @@ where
) -> WritePayloadStatus<B> {
match item {
Some(Ok(item)) => {
trace!("Got response chunk: {:?}", item.len());
trace!("got response chunk: {:?}", item.len());
match self
.state
.write()
@ -637,7 +609,7 @@ where
}
}
None => {
trace!("Response payload eof");
trace!("response payload eof");
if let Err(err) =
self.state.write().encode(Message::Chunk(None), &self.codec)
{
@ -653,7 +625,7 @@ where
}
}
Some(Err(e)) => {
trace!("Error during response body poll: {:?}", e);
trace!("error during response body poll: {:?}", e);
self.error = Some(DispatchError::ResponsePayload(e));
WritePayloadStatus::Next(State::Stop)
}
@ -686,13 +658,13 @@ where
break;
}
Ok(None) => {
if self.state.is_io_err() {
if !self.state.is_io_open() {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
self.error = Some(ParseError::Incomplete.into());
return ReadPayloadStatus::Dropped;
} else {
read.wake(cx.waker());
let _ = read.poll_ready(cx);
break;
}
}
@ -737,6 +709,7 @@ mod tests {
use crate::http::config::{DispatcherConfig, ServiceConfig};
use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler};
use crate::http::{body, Request, ResponseHead, StatusCode};
use crate::io::{self as nio, DefaultFilter};
use crate::service::{boxed, fn_service, IntoService};
use crate::util::{lazy, next, Bytes, BytesMut};
use crate::{codec::Decoder, testing::Io, time::sleep, time::Millis};
@ -747,7 +720,7 @@ mod tests {
pub(crate) fn h1<F, S, B>(
stream: Io,
service: F,
) -> Dispatcher<Io, S, B, ExpectHandler, UpgradeHandler<Io>>
) -> Dispatcher<DefaultFilter, S, B, ExpectHandler, UpgradeHandler<DefaultFilter>>
where
F: IntoService<S>,
S: Service<Request = Request>,
@ -756,7 +729,7 @@ mod tests {
B: MessageBody,
{
Dispatcher::new(
stream,
nio::Io::new(stream),
Rc::new(DispatcherConfig::new(
ServiceConfig::default(),
service.into_service(),
@ -764,8 +737,6 @@ mod tests {
None,
None,
)),
None,
None,
)
}
@ -777,20 +748,22 @@ mod tests {
S::Response: Into<Response<B>>,
B: MessageBody + 'static,
{
crate::rt::spawn(
Dispatcher::<Io, S, B, ExpectHandler, UpgradeHandler<Io>>::new(
stream,
Rc::new(DispatcherConfig::new(
ServiceConfig::default(),
service.into_service(),
ExpectHandler,
None,
None,
)),
crate::rt::spawn(Dispatcher::<
DefaultFilter,
S,
B,
ExpectHandler,
UpgradeHandler<DefaultFilter>,
>::new(
nio::Io::new(stream),
Rc::new(DispatcherConfig::new(
ServiceConfig::default(),
service.into_service(),
ExpectHandler,
None,
None,
),
);
)),
));
}
fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead {
@ -806,7 +779,7 @@ mod tests {
let data = Rc::new(Cell::new(false));
let data2 = data.clone();
let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<Io>>::new(
server,
nio::Io::new(server),
Rc::new(DispatcherConfig::new(
ServiceConfig::default(),
fn_service(|_| {
@ -821,8 +794,6 @@ mod tests {
},
))),
)),
None,
None,
);
sleep(Millis(50)).await;

View file

@ -3,35 +3,33 @@ use std::{
task,
};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::State as IoState;
use crate::http::body::MessageBody;
use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig};
use crate::http::error::{DispatchError, ResponseError};
use crate::http::helpers::DataFactory;
use crate::http::request::Request;
use crate::http::response::Response;
use crate::io::{DefaultFilter, Filter, Io, IoRef};
use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
use crate::{rt::net::TcpStream, time::Millis, util::Pool};
use crate::{time::Millis, util::Pool};
use super::codec::Codec;
use super::dispatcher::Dispatcher;
use super::{ExpectHandler, UpgradeHandler};
/// `ServiceFactory` implementation for HTTP1 transport
pub struct H1Service<T, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
pub struct H1Service<F, S, B, X = ExpectHandler, U = UpgradeHandler<F>> {
srv: S,
cfg: ServiceConfig,
expect: X,
upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_request: RefCell<Option<OnRequest<T>>>,
on_request: RefCell<Option<OnRequest>>,
#[allow(dead_code)]
handshake_timeout: Millis,
_t: marker::PhantomData<(T, B)>,
_t: marker::PhantomData<(F, B)>,
}
impl<T, S, B> H1Service<T, S, B>
impl<F, S, B> H1Service<F, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
@ -40,15 +38,14 @@ where
B: MessageBody,
{
/// Create new `HttpService` instance with config.
pub(crate) fn with_config<F: IntoServiceFactory<S>>(
pub(crate) fn with_config<U: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
service: U,
) -> Self {
H1Service {
srv: service.into_factory(),
expect: ExpectHandler,
upgrade: None,
on_connect: None,
on_request: RefCell::new(None),
handshake_timeout: cfg.0.ssl_handshake_timeout,
_t: marker::PhantomData,
@ -57,53 +54,14 @@ where
}
}
impl<S, B, X, U> H1Service<TcpStream, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Future: 'static,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, TcpStream, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
{
/// Create simple tcp stream service
pub fn tcp(
self,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = DispatchError,
InitError = (),
> {
pipeline_factory(|io: TcpStream| async move {
let peer_addr = io.peer_addr().ok();
Ok((io, peer_addr))
})
.and_then(self)
}
}
#[cfg(feature = "openssl")]
mod openssl {
use super::*;
use crate::server::openssl::{Acceptor, SslAcceptor, SslStream};
use crate::server::openssl::{Acceptor, SslAcceptor, SslFilter};
use crate::server::SslError;
impl<S, B, X, U> H1Service<SslStream<TcpStream>, S, B, X, U>
impl<S, B, X, U> H1Service<SslFilter<DefaultFilter>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
@ -117,7 +75,7 @@ mod openssl {
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, SslStream<TcpStream>, IoState, Codec),
Request = (Request, Io<SslFilter<DefaultFilter>>, Codec),
Response = (),
>,
U::Error: fmt::Display + Error + 'static,
@ -130,7 +88,7 @@ mod openssl {
acceptor: SslAcceptor,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Request = Io,
Response = (),
Error = SslError<DispatchError>,
InitError = (),
@ -141,71 +99,68 @@ mod openssl {
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: SslStream<TcpStream>| async move {
let peer_addr = io.get_ref().peer_addr().ok();
Ok((io, peer_addr))
})
.and_then(self.map_err(SslError::Service))
}
}
}
#[cfg(feature = "rustls")]
mod rustls {
use super::*;
use crate::server::rustls::{Acceptor, ServerConfig, TlsStream};
use crate::server::SslError;
use std::fmt;
// #[cfg(feature = "rustls")]
// mod rustls {
// use super::*;
// use crate::server::rustls::{Acceptor, ServerConfig, TlsStream};
// use crate::server::SslError;
// use std::fmt;
impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>>,
S::Future: 'static,
B: MessageBody,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, TlsStream<TcpStream>, IoState, Codec),
Response = (),
>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
{
/// Create rustls based service
pub fn rustls(
self,
config: ServerConfig,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = SslError<DispatchError>,
InitError = (),
> {
pipeline_factory(
Acceptor::new(config)
.timeout(self.handshake_timeout)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: TlsStream<TcpStream>| async move {
let peer_addr = io.get_ref().0.peer_addr().ok();
Ok((io, peer_addr))
})
.and_then(self.map_err(SslError::Service))
}
}
}
// impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
// where
// S: ServiceFactory<Config = (), Request = Request>,
// S::Error: ResponseError + 'static,
// S::InitError: fmt::Debug,
// S::Response: Into<Response<B>>,
// S::Future: 'static,
// B: MessageBody,
// X: ServiceFactory<Config = (), Request = Request, Response = Request>,
// X::Error: ResponseError + 'static,
// X::InitError: fmt::Debug,
// X::Future: 'static,
// U: ServiceFactory<
// Config = (),
// Request = (Request, TlsStream<TcpStream>, IoState, Codec),
// Response = (),
// >,
// U::Error: fmt::Display + Error + 'static,
// U::InitError: fmt::Debug,
// U::Future: 'static,
// {
// /// Create rustls based service
// pub fn rustls(
// self,
// config: ServerConfig,
// ) -> impl ServiceFactory<
// Config = (),
// Request = TcpStream,
// Response = (),
// Error = SslError<DispatchError>,
// InitError = (),
// > {
// pipeline_factory(
// Acceptor::new(config)
// .timeout(self.handshake_timeout)
// .map_err(SslError::Ssl)
// .map_init_err(|_| panic!()),
// )
// .and_then(|io: TlsStream<TcpStream>| async move {
// let peer_addr = io.get_ref().0.peer_addr().ok();
// Ok((io, peer_addr))
// })
// .and_then(self.map_err(SslError::Service))
// }
// }
// }
impl<T, S, B, X, U> H1Service<T, S, B, X, U>
impl<F, S, B, X, U> H1Service<F, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
@ -213,7 +168,7 @@ where
S::Future: 'static,
B: MessageBody,
{
pub fn expect<X1>(self, expect: X1) -> H1Service<T, S, B, X1, U>
pub fn expect<X1>(self, expect: X1) -> H1Service<F, S, B, X1, U>
where
X1: ServiceFactory<Request = Request, Response = Request>,
X1::Error: ResponseError + 'static,
@ -225,16 +180,15 @@ where
cfg: self.cfg,
srv: self.srv,
upgrade: self.upgrade,
on_connect: self.on_connect,
on_request: self.on_request,
handshake_timeout: self.handshake_timeout,
_t: marker::PhantomData,
}
}
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, S, B, X, U1>
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<F, S, B, X, U1>
where
U1: ServiceFactory<Request = (Request, T, IoState, Codec), Response = ()>,
U1: ServiceFactory<Request = (Request, Io<F>, Codec), Response = ()>,
U1::Error: fmt::Display + Error + 'static,
U1::InitError: fmt::Debug,
U1::Future: 'static,
@ -244,34 +198,24 @@ where
cfg: self.cfg,
srv: self.srv,
expect: self.expect,
on_connect: self.on_connect,
on_request: self.on_request,
handshake_timeout: self.handshake_timeout,
_t: marker::PhantomData,
}
}
/// Set on connect callback.
pub(crate) fn on_connect(
mut self,
f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> Self {
self.on_connect = f;
self
}
/// Set req request callback.
///
/// It get called once per request.
pub(crate) fn on_request(self, f: Option<OnRequest<T>>) -> Self {
pub(crate) fn on_request(self, f: Option<OnRequest>) -> Self {
*self.on_request.borrow_mut() = f;
self
}
}
impl<T, S, B, X, U> ServiceFactory for H1Service<T, S, B, X, U>
impl<F, S, B, X, U> ServiceFactory for H1Service<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter + 'static,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
@ -282,28 +226,23 @@ where
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, T, IoState, Codec),
Response = (),
>,
U: ServiceFactory<Config = (), Request = (Request, Io<F>, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
{
type Config = ();
type Request = (T, Option<net::SocketAddr>);
type Request = Io<F>;
type Response = ();
type Error = DispatchError;
type InitError = ();
type Service = H1ServiceHandler<T, S::Service, B, X::Service, U::Service>;
type Service = H1ServiceHandler<F, S::Service, B, X::Service, U::Service>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Self::InitError>>>>;
fn new_service(&self, _: ()) -> Self::Future {
let fut = self.srv.new_service(());
let fut_ex = self.expect.new_service(());
let fut_upg = self.upgrade.as_ref().map(|f| f.new_service(()));
let on_connect = self.on_connect.clone();
let on_request = self.on_request.borrow_mut().take();
let cfg = self.cfg.clone();
@ -331,7 +270,6 @@ where
Ok(H1ServiceHandler {
pool,
config,
on_connect,
_t: marker::PhantomData,
})
})
@ -339,29 +277,28 @@ where
}
/// `Service` implementation for HTTP1 transport
pub struct H1ServiceHandler<T, S: Service, B, X: Service, U: Service> {
pub struct H1ServiceHandler<F, S: Service, B, X: Service, U: Service> {
pool: Pool,
config: Rc<DispatcherConfig<T, S, X, U>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: marker::PhantomData<(T, B)>,
config: Rc<DispatcherConfig<S, X, U>>,
_t: marker::PhantomData<(F, B)>,
}
impl<T, S, B, X, U> Service for H1ServiceHandler<T, S, B, X, U>
impl<F, S, B, X, U> Service for H1ServiceHandler<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter + 'static,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
U: Service<Request = (Request, Io<F>, Codec), Response = ()>,
U::Error: fmt::Display + Error + 'static,
{
type Request = (T, Option<net::SocketAddr>);
type Request = Io<F>;
type Response = ();
type Error = DispatchError;
type Future = Dispatcher<T, S, B, X, U>;
type Future = Dispatcher<F, S, B, X, U>;
fn poll_ready(
&self,
@ -429,9 +366,7 @@ where
}
}
fn call(&self, (io, addr): Self::Request) -> Self::Future {
let on_connect = self.on_connect.as_ref().map(|f| f(&io));
Dispatcher::new(io, self.config.clone(), addr, on_connect)
fn call(&self, io: Self::Request) -> Self::Future {
Dispatcher::new(io, self.config.clone())
}
}

View file

@ -2,16 +2,17 @@ use std::{io, marker::PhantomData, task::Context, task::Poll};
use crate::http::h1::Codec;
use crate::http::request::Request;
use crate::{framed::State, util::Ready, Service, ServiceFactory};
use crate::io::Io;
use crate::{util::Ready, Service, ServiceFactory};
pub struct UpgradeHandler<T>(PhantomData<T>);
pub struct UpgradeHandler<F>(PhantomData<F>);
impl<T> ServiceFactory for UpgradeHandler<T> {
impl<F> ServiceFactory for UpgradeHandler<F> {
type Config = ();
type Request = (Request, T, State, Codec);
type Request = (Request, Io<F>, Codec);
type Response = ();
type Error = io::Error;
type Service = UpgradeHandler<T>;
type Service = UpgradeHandler<F>;
type InitError = io::Error;
type Future = Ready<Self::Service, Self::InitError>;
@ -21,8 +22,8 @@ impl<T> ServiceFactory for UpgradeHandler<T> {
}
}
impl<T> Service for UpgradeHandler<T> {
type Request = (Request, T, State, Codec);
impl<F> Service for UpgradeHandler<F> {
type Request = (Request, Io<F>, Codec);
type Response = ();
type Error = io::Error;
type Future = Ready<Self::Response, Self::Error>;

View file

@ -4,11 +4,11 @@ use std::task::{Context, Poll};
use h2::RecvStream;
mod dispatcher;
mod service;
//mod dispatcher;
//mod service;
pub use self::dispatcher::Dispatcher;
pub use self::service::H2Service;
//pub use self::dispatcher::Dispatcher;
//pub use self::service::H2Service;
use crate::{http::error::PayloadError, util::Bytes, Stream};
/// H2 receive stream

View file

@ -6,6 +6,7 @@ use bitflags::bitflags;
use crate::http::header::HeaderMap;
use crate::http::{header, Method, StatusCode, Uri, Version};
use crate::io::IoRef;
use crate::util::Extensions;
/// Represents various types of connection
@ -45,19 +46,19 @@ pub struct RequestHead {
pub version: Version,
pub headers: HeaderMap,
pub extensions: RefCell<Extensions>,
pub peer_addr: Option<net::SocketAddr>,
pub io: Option<IoRef>,
pub(super) flags: Flags,
}
impl Default for RequestHead {
fn default() -> RequestHead {
RequestHead {
io: None,
uri: Uri::default(),
method: Method::default(),
version: Version::HTTP_11,
headers: HeaderMap::with_capacity(16),
flags: Flags::empty(),
peer_addr: None,
extensions: RefCell::new(Extensions::new()),
}
}
@ -65,6 +66,7 @@ impl Default for RequestHead {
impl Head for RequestHead {
fn clear(&mut self) {
self.io = None;
self.flags = Flags::empty();
self.headers.clear();
self.extensions.get_mut().clear();

View file

@ -6,6 +6,7 @@ use crate::http::header::HeaderMap;
use crate::http::httpmessage::HttpMessage;
use crate::http::message::{Message, RequestHead};
use crate::http::payload::Payload;
use crate::io::IoRef;
use crate::util::Extensions;
/// Request
@ -126,13 +127,21 @@ impl Request {
self.head().method == Method::CONNECT
}
/// Io reference for current connection
#[inline]
pub fn io(&self) -> Option<&IoRef> {
self.head().io.as_ref()
}
/// Peer socket address
///
/// Peer address is actual socket address, if proxy is used in front of
/// ntex http server, then peer address would be address of this proxy.
#[inline]
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
self.head().peer_addr
// TODO! fix
// self.head().peer_addr
None
}
/// Get request's payload

View file

@ -6,7 +6,7 @@ use std::{
use h2::server::{self, Handshake};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::framed::State;
use crate::io::{DefaultFilter, Filter, Io, IoRef};
use crate::rt::net::TcpStream;
use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
use crate::time::{Millis, Seconds};
@ -19,20 +19,20 @@ use super::error::{DispatchError, ResponseError};
use super::helpers::DataFactory;
use super::request::Request;
use super::response::Response;
use super::{h1, h2::Dispatcher, Protocol};
//use super::{h1, h2::Dispatcher, Protocol};
use super::{h1, Protocol};
/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation
pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> {
pub struct HttpService<F, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<F>> {
srv: S,
cfg: ServiceConfig,
expect: X,
upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_request: cell::RefCell<Option<OnRequest<T>>>,
_t: marker::PhantomData<(T, B)>,
on_request: cell::RefCell<Option<OnRequest>>,
_t: marker::PhantomData<(F, B)>,
}
impl<T, S, B> HttpService<T, S, B>
impl<F, S, B> HttpService<F, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
@ -43,13 +43,14 @@ where
B: MessageBody + 'static,
{
/// Create builder for `HttpService` instance.
pub fn build() -> HttpServiceBuilder<T, S> {
pub fn build() -> HttpServiceBuilder<F, S> {
HttpServiceBuilder::new()
}
}
impl<T, S, B> HttpService<T, S, B>
impl<F, S, B> HttpService<F, S, B>
where
F: Filter,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
@ -59,7 +60,7 @@ where
B: MessageBody + 'static,
{
/// Create new `HttpService` instance.
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
pub fn new<U: IntoServiceFactory<S>>(service: U) -> Self {
let cfg = ServiceConfig::new(
KeepAlive::Timeout(Seconds(5)),
Millis(5_000),
@ -73,31 +74,30 @@ where
srv: service.into_factory(),
expect: h1::ExpectHandler,
upgrade: None,
on_connect: None,
on_request: cell::RefCell::new(None),
_t: marker::PhantomData,
}
}
/// Create new `HttpService` instance with config.
pub(crate) fn with_config<F: IntoServiceFactory<S>>(
pub(crate) fn with_config<U: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
service: U,
) -> Self {
HttpService {
cfg,
srv: service.into_factory(),
expect: h1::ExpectHandler,
upgrade: None,
on_connect: None,
on_request: cell::RefCell::new(None),
_t: marker::PhantomData,
}
}
}
impl<T, S, B, X, U> HttpService<T, S, B, X, U>
impl<F, S, B, X, U> HttpService<F, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
@ -111,7 +111,7 @@ where
/// Service get called with request that contains `EXPECT` header.
/// Service must return request in case of success, in that case
/// request will be forwarded to main service.
pub fn expect<X1>(self, expect: X1) -> HttpService<T, S, B, X1, U>
pub fn expect<X1>(self, expect: X1) -> HttpService<F, S, B, X1, U>
where
X1: ServiceFactory<Config = (), Request = Request, Response = Request>,
X1::Error: ResponseError,
@ -123,7 +123,6 @@ where
cfg: self.cfg,
srv: self.srv,
upgrade: self.upgrade,
on_connect: self.on_connect,
on_request: self.on_request,
_t: marker::PhantomData,
}
@ -133,11 +132,11 @@ where
///
/// If service is provided then normal requests handling get halted
/// and this service get called with original request and framed object.
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, S, B, X, U1>
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<F, S, B, X, U1>
where
U1: ServiceFactory<
Config = (),
Request = (Request, T, State, h1::Codec),
Request = (Request, Io<F>, h1::Codec),
Response = (),
>,
U1::Error: fmt::Display + error::Error + 'static,
@ -149,77 +148,25 @@ where
cfg: self.cfg,
srv: self.srv,
expect: self.expect,
on_connect: self.on_connect,
on_request: self.on_request,
_t: marker::PhantomData,
}
}
/// Set on connect callback.
pub(crate) fn on_connect(
mut self,
f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> Self {
self.on_connect = f;
self
}
/// Set on request callback.
pub(crate) fn on_request(self, f: Option<OnRequest<T>>) -> Self {
pub(crate) fn on_request(self, f: Option<OnRequest>) -> Self {
*self.on_request.borrow_mut() = f;
self
}
}
impl<S, B, X, U> HttpService<TcpStream, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, TcpStream, State, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
<U::Service as Service>::Future: 'static,
{
/// Create simple tcp stream service
pub fn tcp(
self,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = DispatchError,
InitError = (),
> {
pipeline_factory(|io: TcpStream| async move {
let peer_addr = io.peer_addr().ok();
Ok((io, Protocol::Http1, peer_addr))
})
.and_then(self)
}
}
#[cfg(feature = "openssl")]
mod openssl {
use super::*;
use crate::server::openssl::{Acceptor, SslAcceptor, SslStream};
use crate::server::openssl::{Acceptor, SslAcceptor, SslFilter};
use crate::server::SslError;
impl<S, B, X, U> HttpService<SslStream<TcpStream>, S, B, X, U>
impl<S, B, X, U> HttpService<SslFilter<DefaultFilter>, S, B, X, U>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
@ -235,7 +182,7 @@ mod openssl {
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, SslStream<TcpStream>, State, h1::Codec),
Request = (Request, Io<SslFilter<DefaultFilter>>, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + error::Error + 'static,
@ -249,7 +196,7 @@ mod openssl {
acceptor: SslAcceptor,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Request = Io<DefaultFilter>,
Response = (),
Error = SslError<DispatchError>,
InitError = (),
@ -260,19 +207,6 @@ mod openssl {
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: SslStream<TcpStream>| async move {
let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() {
if protos.windows(2).any(|window| window == b"h2") {
Protocol::Http2
} else {
Protocol::Http1
}
} else {
Protocol::Http1
};
let peer_addr = io.get_ref().peer_addr().ok();
Ok((io, proto, peer_addr))
})
.and_then(self.map_err(SslError::Service))
}
}
@ -284,8 +218,9 @@ mod rustls {
use crate::server::rustls::{Acceptor, ServerConfig, TlsStream};
use crate::server::SslError;
impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
impl<F, S, B, X, U> HttpService<F, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
@ -300,7 +235,7 @@ mod rustls {
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, TlsStream<TcpStream>, State, h1::Codec),
Request = (Request, Io<F>, h1::Codec),
Response = (),
>,
U::Error: fmt::Display + error::Error + 'static,
@ -311,47 +246,49 @@ mod rustls {
/// Create openssl based service
pub fn rustls(
self,
mut config: ServerConfig,
config: ServerConfig,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Request = Io<F>,
Response = (),
Error = SslError<DispatchError>,
//Error = SslError<DispatchError>,
Error = DispatchError,
InitError = (),
> {
let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()];
config.alpn_protocols = protos;
self
// let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()];
// config.alpn_protocols = protos;
pipeline_factory(
Acceptor::new(config)
.timeout(self.cfg.0.ssl_handshake_timeout)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!()),
)
.and_then(|io: TlsStream<TcpStream>| async move {
let proto = io
.get_ref()
.1
.alpn_protocol()
.and_then(|protos| {
if protos.windows(2).any(|window| window == b"h2") {
Some(Protocol::Http2)
} else {
None
}
})
.unwrap_or(Protocol::Http1);
let peer_addr = io.get_ref().0.peer_addr().ok();
Ok((io, proto, peer_addr))
})
.and_then(self.map_err(SslError::Service))
// pipeline_factory(
// Acceptor::new(config)
// .timeout(self.cfg.0.ssl_handshake_timeout)
// .map_err(SslError::Ssl)
// .map_init_err(|_| panic!()),
// )
// .and_then(|io: TlsStream<TcpStream>| async move {
// let proto = io
// .get_ref()
// .1
// .alpn_protocol()
// .and_then(|protos| {
// if protos.windows(2).any(|window| window == b"h2") {
// Some(Protocol::Http2)
// } else {
// None
// }
// })
// .unwrap_or(Protocol::Http1);
// let peer_addr = io.get_ref().0.peer_addr().ok();
// Ok((io, proto, peer_addr))
// })
// .and_then(self.map_err(SslError::Service))
}
}
}
impl<T, S, B, X, U> ServiceFactory for HttpService<T, S, B, X, U>
impl<F, S, B, X, U> ServiceFactory for HttpService<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter + 'static,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::InitError: fmt::Debug,
@ -364,29 +301,24 @@ where
X::InitError: fmt::Debug,
X::Future: 'static,
<X::Service as Service>::Future: 'static,
U: ServiceFactory<
Config = (),
Request = (Request, T, State, h1::Codec),
Response = (),
>,
U: ServiceFactory<Config = (), Request = (Request, Io<F>, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
U::InitError: fmt::Debug,
U::Future: 'static,
<U::Service as Service>::Future: 'static,
{
type Config = ();
type Request = (T, Protocol, Option<net::SocketAddr>);
type Request = Io<F>;
type Response = ();
type Error = DispatchError;
type InitError = ();
type Service = HttpServiceHandler<T, S::Service, B, X::Service, U::Service>;
type Service = HttpServiceHandler<F, S::Service, B, X::Service, U::Service>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Self::InitError>>>>;
fn new_service(&self, _: ()) -> Self::Future {
let fut = self.srv.new_service(());
let fut_ex = self.expect.new_service(());
let fut_upg = self.upgrade.as_ref().map(|f| f.new_service(()));
let on_connect = self.on_connect.clone();
let on_request = self.on_request.borrow_mut().take();
let cfg = self.cfg.clone();
@ -414,7 +346,6 @@ where
Ok(HttpServiceHandler {
pool,
on_connect,
config: Rc::new(config),
_t: marker::PhantomData,
})
@ -423,16 +354,15 @@ where
}
/// `Service` implementation for http transport
pub struct HttpServiceHandler<T, S: Service, B, X: Service, U: Service> {
pub struct HttpServiceHandler<F, S: Service, B, X: Service, U: Service> {
pool: Pool,
config: Rc<DispatcherConfig<T, S, X, U>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: marker::PhantomData<(T, B, X)>,
config: Rc<DispatcherConfig<S, X, U>>,
_t: marker::PhantomData<(F, B, X)>,
}
impl<T, S, B, X, U> Service for HttpServiceHandler<T, S, B, X, U>
impl<F, S, B, X, U> Service for HttpServiceHandler<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter + 'static,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Future: 'static,
@ -440,13 +370,13 @@ where
B: MessageBody + 'static,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
{
type Request = (T, Protocol, Option<net::SocketAddr>);
type Request = Io<F>;
type Response = ();
type Error = DispatchError;
type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
type Future = HttpServiceHandlerResponse<F, S, B, X, U>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let cfg = self.config.as_ref();
@ -507,46 +437,36 @@ where
}
}
fn call(&self, (io, proto, peer_addr): Self::Request) -> Self::Future {
log::trace!(
"New http connection protocol {:?} peer address {:?}",
proto,
peer_addr
);
let on_connect = self.on_connect.as_ref().map(|f| f(&io));
fn call(&self, io: Self::Request) -> Self::Future {
// log::trace!("New http connection protocol {:?}", proto);
match proto {
Protocol::Http2 => HttpServiceHandlerResponse {
state: ResponseState::H2Handshake {
data: Some((
server::Builder::new().handshake(io),
self.config.clone(),
on_connect,
peer_addr,
)),
},
},
Protocol::Http1 => HttpServiceHandlerResponse {
state: ResponseState::H1 {
fut: h1::Dispatcher::new(
io,
self.config.clone(),
peer_addr,
on_connect,
),
},
//match proto {
//Protocol::Http2 => todo!(),
// HttpServiceHandlerResponse {
// state: ResponseState::H2Handshake {
// data: Some((
// server::Builder::new().handshake(io),
// self.config.clone(),
// on_connect,
// peer_addr,
// )),
// },
// },
// Protocol::Http1 =>
HttpServiceHandlerResponse {
state: ResponseState::H1 {
fut: h1::Dispatcher::new(io, self.config.clone()),
},
// },
}
}
}
pin_project_lite::pin_project! {
pub struct HttpServiceHandlerResponse<T, S, B, X, U>
pub struct HttpServiceHandlerResponse<F, S, B, X, U>
where
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
T: 'static,
F: Filter,
F: 'static,
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: 'static,
@ -557,52 +477,50 @@ pin_project_lite::pin_project! {
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: 'static,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
U::Error: fmt::Display,
U::Error: error::Error,
U::Error: 'static,
{
#[pin]
state: ResponseState<T, S, B, X, U>,
state: ResponseState<F, S, B, X, U>,
}
}
pin_project_lite::pin_project! {
#[project = StateProject]
enum ResponseState<T, S, B, X, U>
enum ResponseState<F, S, B, X, U>
where
S: Service<Request = Request>,
S::Error: ResponseError,
S::Error: 'static,
T: AsyncRead,
T: AsyncWrite,
T: Unpin,
T: 'static,
F: Filter,
F: 'static,
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError,
X::Error: 'static,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
U::Error: fmt::Display,
U::Error: error::Error,
U::Error: 'static,
{
H1 { #[pin] fut: h1::Dispatcher<T, S, B, X, U> },
H2 { fut: Dispatcher<T, S, B, X, U> },
H2Handshake { data:
Option<(
Handshake<T, Bytes>,
Rc<DispatcherConfig<T, S, X, U>>,
Option<Box<dyn DataFactory>>,
Option<net::SocketAddr>,
)>,
},
H1 { #[pin] fut: h1::Dispatcher<F, S, B, X, U> },
// H2 { fut: Dispatcher<F, S, B, X, U> },
// H2Handshake { data:
// Option<(
// Handshake<T, Bytes>,
// Rc<DispatcherConfig<S, X, U>>,
// Option<Box<dyn DataFactory>>,
// Option<net::SocketAddr>,
// )>,
// },
}
}
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
impl<F, S, B, X, U> Future for HttpServiceHandlerResponse<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter + 'static,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Future: 'static,
@ -610,7 +528,7 @@ where
B: MessageBody,
X: Service<Request = Request, Response = Request>,
X::Error: ResponseError + 'static,
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
U::Error: fmt::Display + error::Error + 'static,
{
type Output = Result<(), DispatchError>;
@ -620,26 +538,26 @@ where
match this.state.project() {
StateProject::H1 { fut } => fut.poll(cx),
StateProject::H2 { ref mut fut } => Pin::new(fut).poll(cx),
StateProject::H2Handshake { data } => {
let conn = if let Some(ref mut item) = data {
match Pin::new(&mut item.0).poll(cx) {
Poll::Ready(Ok(conn)) => conn,
Poll::Ready(Err(err)) => {
trace!("H2 handshake error: {}", err);
return Poll::Ready(Err(err.into()));
}
Poll::Pending => return Poll::Pending,
}
} else {
panic!()
};
let (_, cfg, on_connect, peer_addr) = data.take().unwrap();
self.as_mut().project().state.set(ResponseState::H2 {
fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr),
});
self.poll(cx)
}
// StateProject::H2 { ref mut fut } => Pin::new(fut).poll(cx),
// StateProject::H2Handshake { data } => {
// let conn = if let Some(ref mut item) = data {
// match Pin::new(&mut item.0).poll(cx) {
// Poll::Ready(Ok(conn)) => conn,
// Poll::Ready(Err(err)) => {
// trace!("H2 handshake error: {}", err);
// return Poll::Ready(Err(err.into()));
// }
// Poll::Pending => return Poll::Pending,
// }
// } else {
// panic!()
// };
// let (_, cfg, on_connect, peer_addr) = data.take().unwrap();
// self.as_mut().project().state.set(ResponseState::H2 {
// fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr),
// });
// self.poll(cx)
// }
}
}
}

View file

@ -5,12 +5,15 @@ use std::{convert::TryFrom, io, net, str::FromStr, sync::mpsc, thread};
use coo_kie::{Cookie, CookieJar};
use crate::codec::{AsyncRead, AsyncWrite, Framed};
use crate::io::IoBoxed;
use crate::rt::{net::TcpStream, System};
use crate::server::{Server, StreamServiceFactory};
use crate::{time::Millis, time::Seconds, util::Bytes};
use super::client::error::WsClientError;
use super::client::{Client, ClientRequest, ClientResponse, Connector};
use super::client::{
ws::WsConnection, Client, ClientRequest, ClientResponse, Connector,
};
use super::error::{HttpError, PayloadError};
use super::header::{HeaderMap, HeaderName, HeaderValue};
use super::payload::Payload;
@ -207,7 +210,7 @@ fn parts(parts: &mut Option<Inner>) -> &mut Inner {
/// assert!(response.status().is_success());
/// }
/// ```
pub fn server<F: StreamServiceFactory<TcpStream>>(factory: F) -> TestServer {
pub fn server<F: StreamServiceFactory>(factory: F) -> TestServer {
let (tx, rx) = mpsc::channel();
// run server in separate thread
@ -318,21 +321,12 @@ impl TestServer {
}
/// Connect to websocket server at a given path
pub async fn ws_at(
&mut self,
path: &str,
) -> Result<Framed<impl AsyncRead + AsyncWrite, crate::ws::Codec>, WsClientError>
{
let url = self.url(path);
let connect = self.client.ws(url).connect();
connect.await.map(|ws| ws.into_inner().1)
pub async fn ws_at(&mut self, path: &str) -> Result<WsConnection, WsClientError> {
self.client.ws(self.url(path)).connect().await
}
/// Connect to a websocket server
pub async fn ws(
&mut self,
) -> Result<Framed<impl AsyncRead + AsyncWrite, crate::ws::Codec>, WsClientError>
{
pub async fn ws(&mut self) -> Result<WsConnection, WsClientError> {
self.ws_at("/").await
}

View file

@ -7,12 +7,12 @@
//! * `compress` - enables compression support in http and web modules
//! * `cookie` - enables cookie support in http and web modules
#![warn(
rust_2018_idioms,
unreachable_pub,
// missing_debug_implementations,
// missing_docs,
)]
//#![warn(
// rust_2018_idioms,
// unreachable_pub,
// missing_debug_implementations,
// missing_docs,
//)]
#![allow(
type_alias_bounds,
clippy::type_complexity,
@ -21,6 +21,7 @@
clippy::too_many_arguments,
clippy::new_without_default
)]
#![allow(unused_imports)]
#[macro_use]
extern crate log;
@ -35,7 +36,7 @@ pub(crate) use ntex_macros::rt_test2 as rt_test;
pub mod channel;
pub mod connect;
pub mod framed;
//pub mod framed;
#[cfg(feature = "http-framework")]
pub mod http;
pub mod server;

View file

@ -193,7 +193,7 @@ impl ServerBuilder {
factory: F,
) -> io::Result<Self>
where
F: StreamServiceFactory<TcpStream>,
F: StreamServiceFactory,
U: net::ToSocketAddrs,
{
let sockets = bind_addr(addr, self.backlog)?;
@ -219,7 +219,7 @@ impl ServerBuilder {
/// Add new unix domain service to the server.
pub fn bind_uds<F, U, N>(self, name: N, addr: U, factory: F) -> io::Result<Self>
where
F: StreamServiceFactory<crate::rt::net::UnixStream>,
F: StreamServiceFactory,
N: AsRef<str>,
U: AsRef<std::path::Path>,
{
@ -249,7 +249,7 @@ impl ServerBuilder {
factory: F,
) -> io::Result<Self>
where
F: StreamServiceFactory<crate::rt::net::UnixStream>,
F: StreamServiceFactory,
{
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
let token = self.token.next();
@ -273,7 +273,7 @@ impl ServerBuilder {
factory: F,
) -> io::Result<Self>
where
F: StreamServiceFactory<TcpStream>,
F: StreamServiceFactory,
{
let token = self.token.next();
self.services.push(Factory::create(

View file

@ -6,8 +6,8 @@ use std::{
use log::error;
use crate::rt::net::TcpStream;
use crate::service;
use crate::util::{counter::CounterGuard, HashMap, Ready};
use crate::{io::Io, service};
use super::builder::bind_addr;
use super::service::{
@ -199,7 +199,7 @@ impl InternalServiceFactory for ConfiguredService {
res.push((
token,
Box::new(StreamService::new(service::fn_service(
move |_: TcpStream| {
move |_: Io| {
error!("Service {:?} is not configured", name);
Ready::<_, ()>::Ok(())
},
@ -292,7 +292,7 @@ impl ServiceRuntime {
pub fn service<T, F>(&self, name: &str, service: F)
where
F: service::IntoServiceFactory<T>,
T: service::ServiceFactory<Config = (), Request = TcpStream> + 'static,
T: service::ServiceFactory<Config = (), Request = Io> + 'static,
T::Future: 'static,
T::Service: 'static,
T::InitError: fmt::Debug,
@ -338,7 +338,7 @@ struct ServiceFactory<T> {
impl<T> service::ServiceFactory for ServiceFactory<T>
where
T: service::ServiceFactory<Config = (), Request = TcpStream>,
T: service::ServiceFactory<Config = (), Request = Io>,
T::Future: 'static,
T::Service: 'static,
T::Error: 'static,

View file

@ -30,9 +30,6 @@ pub use self::config::{ServiceConfig, ServiceRuntime};
pub use self::service::StreamServiceFactory;
pub use self::test::{build_test_server, test_server, TestServer};
#[doc(hidden)]
pub use self::socket::FromStream;
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
/// Server readiness status

View file

@ -1,10 +1,13 @@
use std::task::{Context, Poll};
use std::{error::Error, fmt, future::Future, io, marker::PhantomData, pin::Pin};
pub use ntex_openssl::SslFilter;
pub use open_ssl::ssl::{self, AlpnError, Ssl, SslAcceptor, SslAcceptorBuilder};
pub use tokio_openssl::SslStream;
use ntex_openssl::SslAcceptor as IoSslAcceptor;
use crate::codec::{AsyncRead, AsyncWrite};
use crate::io::{Filter, FilterFactory, Io};
use crate::service::{Service, ServiceFactory};
use crate::time::{sleep, Millis, Sleep};
use crate::util::{counter::Counter, counter::CounterGuard, Ready};
@ -14,19 +17,17 @@ use super::MAX_SSL_ACCEPT_COUNTER;
/// Support `TLS` server connections via openssl package
///
/// `openssl` feature enables `Acceptor` type
pub struct Acceptor<T: AsyncRead + AsyncWrite> {
acceptor: SslAcceptor,
timeout: Millis,
io: PhantomData<T>,
pub struct Acceptor<F> {
acceptor: IoSslAcceptor,
_t: PhantomData<F>,
}
impl<T: AsyncRead + AsyncWrite> Acceptor<T> {
impl<F> Acceptor<F> {
/// Create default openssl acceptor service
pub fn new(acceptor: SslAcceptor) -> Self {
Acceptor {
acceptor,
timeout: Millis(5_000),
io: PhantomData,
acceptor: IoSslAcceptor::new(acceptor),
_t: PhantomData,
}
}
@ -34,30 +35,26 @@ impl<T: AsyncRead + AsyncWrite> Acceptor<T> {
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
self.timeout = timeout.into();
self.acceptor.timeout(timeout);
self
}
}
impl<T: AsyncRead + AsyncWrite> Clone for Acceptor<T> {
impl<F> Clone for Acceptor<F> {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
timeout: self.timeout,
io: PhantomData,
_t: PhantomData,
}
}
}
impl<T> ServiceFactory for Acceptor<T>
where
T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
type Request = T;
type Response = SslStream<T>;
impl<F: Filter + 'static> ServiceFactory for Acceptor<F> {
type Request = Io<F>;
type Response = Io<SslFilter<F>>;
type Error = Box<dyn Error>;
type Config = ();
type Service = AcceptorService<T>;
type Service = AcceptorService<F>;
type InitError = ();
type Future = Ready<Self::Service, Self::InitError>;
@ -66,28 +63,23 @@ where
Ready::Ok(AcceptorService {
acceptor: self.acceptor.clone(),
conns: conns.priv_clone(),
timeout: self.timeout,
io: PhantomData,
_t: PhantomData,
})
})
}
}
pub struct AcceptorService<T> {
acceptor: SslAcceptor,
pub struct AcceptorService<F> {
acceptor: IoSslAcceptor,
conns: Counter,
timeout: Millis,
io: PhantomData<T>,
_t: PhantomData<F>,
}
impl<T> Service for AcceptorService<T>
where
T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
{
type Request = T;
type Response = SslStream<T>;
impl<F: Filter + 'static> Service for AcceptorService<F> {
type Request = Io<F>;
type Response = Io<SslFilter<F>>;
type Error = Box<dyn Error>;
type Future = AcceptorServiceResponse<T>;
type Future = AcceptorServiceResponse<F>;
#[inline]
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -100,57 +92,29 @@ where
#[inline]
fn call(&self, req: Self::Request) -> Self::Future {
let ssl = Ssl::new(self.acceptor.context())
.expect("Provided SSL acceptor was invalid.");
AcceptorServiceResponse {
_guard: self.conns.get(),
io: None,
delay: self.timeout.map(sleep),
io_factory: Some(SslStream::new(ssl, req)),
fut: self.acceptor.clone().create(req),
}
}
}
pub struct AcceptorServiceResponse<T>
where
T: AsyncRead,
T: AsyncWrite,
{
io: Option<SslStream<T>>,
delay: Option<Sleep>,
io_factory: Option<Result<SslStream<T>, open_ssl::error::ErrorStack>>,
_guard: CounterGuard,
}
impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceResponse<T> {
type Output = Result<SslStream<T>, Box<dyn Error>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut();
if let Some(ref delay) = this.delay {
match delay.poll_elapsed(cx) {
Poll::Pending => (),
Poll::Ready(_) => {
return Poll::Ready(Err(Box::new(io::Error::new(
io::ErrorKind::TimedOut,
"ssl handshake timeout",
))))
}
}
}
match this.io_factory.take() {
Some(Ok(io)) => this.io = Some(io),
Some(Err(err)) => return Poll::Ready(Err(Box::new(err))),
None => (),
}
let io = this.io.as_mut().unwrap();
match Pin::new(io).poll_accept(cx) {
Poll::Ready(Ok(_)) => Poll::Ready(Ok(this.io.take().unwrap())),
Poll::Ready(Err(e)) => Poll::Ready(Err(Box::new(e))),
Poll::Pending => Poll::Pending,
}
pin_project_lite::pin_project! {
pub struct AcceptorServiceResponse<F>
where
F: Filter,
F: 'static,
{
#[pin]
fut: <IoSslAcceptor as FilterFactory<F>>::Future,
_guard: CounterGuard,
}
}
impl<F: Filter + 'static> Future for AcceptorServiceResponse<F> {
type Output = Result<Io<SslFilter<F>>, Box<dyn Error>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().fut.poll(cx)
}
}

View file

@ -1,3 +1,4 @@
use std::convert::TryInto;
use std::{
future::Future, marker::PhantomData, net::SocketAddr, pin::Pin, task::Context,
task::Poll,
@ -5,12 +6,12 @@ use std::{
use log::error;
use crate::io::Io;
use crate::service::{Service, ServiceFactory};
use crate::util::{counter::CounterGuard, Ready};
use crate::{rt::spawn, time::Millis};
use super::socket::{FromStream, Stream};
use super::Token;
use super::{socket::Stream, Token};
/// Server message
pub(super) enum ServerMessage {
@ -22,8 +23,8 @@ pub(super) enum ServerMessage {
ForceShutdown,
}
pub trait StreamServiceFactory<Stream: FromStream>: Send + Clone + 'static {
type Factory: ServiceFactory<Config = (), Request = Stream>;
pub trait StreamServiceFactory: Send + Clone + 'static {
type Factory: ServiceFactory<Config = (), Request = Io>;
fn create(&self) -> Self::Factory;
}
@ -57,12 +58,11 @@ impl<T> StreamService<T> {
}
}
impl<T, I> Service for StreamService<T>
impl<T> Service for StreamService<T>
where
T: Service<Request = I>,
T: Service<Request = Io>,
T::Future: 'static,
T::Error: 'static,
I: FromStream,
{
type Request = (Option<CounterGuard>, ServerMessage);
type Response = ();
@ -82,7 +82,7 @@ where
fn call(&self, (guard, req): (Option<CounterGuard>, ServerMessage)) -> Self::Future {
match req {
ServerMessage::Connect(stream) => {
let stream = FromStream::from_stream(stream).map_err(|e| {
let stream = stream.try_into().map_err(|e| {
error!("Cannot convert to an async io stream: {}", e);
});
@ -102,18 +102,16 @@ where
}
}
pub(super) struct Factory<F: StreamServiceFactory<Io>, Io: FromStream> {
pub(super) struct Factory<F: StreamServiceFactory> {
name: String,
inner: F,
token: Token,
addr: SocketAddr,
_t: PhantomData<Io>,
}
impl<F, Io> Factory<F, Io>
impl<F> Factory<F>
where
F: StreamServiceFactory<Io>,
Io: FromStream + Send + 'static,
F: StreamServiceFactory,
{
pub(crate) fn create(
name: String,
@ -126,15 +124,13 @@ where
token,
inner,
addr,
_t: PhantomData,
})
}
}
impl<F, Io> InternalServiceFactory for Factory<F, Io>
impl<F> InternalServiceFactory for Factory<F>
where
F: StreamServiceFactory<Io>,
Io: FromStream + Send + 'static,
F: StreamServiceFactory,
{
fn name(&self, _: Token) -> &str {
&self.name
@ -146,7 +142,6 @@ where
inner: self.inner.clone(),
token: self.token,
addr: self.addr,
_t: PhantomData,
})
}
@ -187,11 +182,10 @@ impl InternalServiceFactory for Box<dyn InternalServiceFactory> {
}
}
impl<F, T, I> StreamServiceFactory<I> for F
impl<F, T> StreamServiceFactory for F
where
F: Fn() -> T + Send + Clone + 'static,
T: ServiceFactory<Config = (), Request = I>,
I: FromStream,
T: ServiceFactory<Config = (), Request = Io>,
{
type Factory = T;

View file

@ -1,6 +1,7 @@
use std::{fmt, io, net};
use std::{convert::TryFrom, fmt, io, net};
use crate::codec::{AsyncRead, AsyncWrite};
use crate::io::{Io, IoStream};
use crate::rt::net::TcpStream;
pub(crate) enum Listener {
@ -146,32 +147,29 @@ pub enum Stream {
Uds(mio::net::UnixStream),
}
pub trait FromStream: AsyncRead + AsyncWrite + Sized {
fn from_stream(stream: Stream) -> io::Result<Self>;
}
impl TryFrom<Stream> for Io {
type Error = io::Error;
#[cfg(unix)]
impl FromStream for TcpStream {
fn from_stream(sock: Stream) -> io::Result<Self> {
fn try_from(sock: Stream) -> Result<Self, Self::Error> {
#[cfg(unix)]
match sock {
Stream::Tcp(stream) => {
use std::os::unix::io::{FromRawFd, IntoRawFd};
let fd = IntoRawFd::into_raw_fd(stream);
let io = TcpStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) })?;
io.set_nodelay(true)?;
Ok(io)
Ok(Io::new(io))
}
#[cfg(unix)]
Stream::Uds(_) => {
panic!("Should not happen, bug in server impl");
Stream::Uds(stream) => {
use crate::rt::net::UnixStream;
use std::os::unix::io::{FromRawFd, IntoRawFd};
let fd = IntoRawFd::into_raw_fd(stream);
let ud = UnixStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) });
todo!()
}
}
}
}
#[cfg(windows)]
impl FromStream for TcpStream {
fn from_stream(sock: Stream) -> io::Result<Self> {
#[cfg(windows)]
match sock {
Stream::Tcp(stream) => {
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
@ -179,26 +177,7 @@ impl FromStream for TcpStream {
let io =
TcpStream::from_std(unsafe { FromRawSocket::from_raw_socket(fd) })?;
io.set_nodelay(true)?;
Ok(io)
}
#[cfg(unix)]
Stream::Uds(_) => {
panic!("Should not happen, bug in server impl");
}
}
}
}
#[cfg(unix)]
impl FromStream for crate::rt::net::UnixStream {
fn from_stream(sock: Stream) -> io::Result<Self> {
match sock {
Stream::Tcp(_) => panic!("Should not happen, bug in server impl"),
Stream::Uds(stream) => {
use crate::rt::net::UnixStream;
use std::os::unix::io::{FromRawFd, IntoRawFd};
let fd = IntoRawFd::into_raw_fd(stream);
UnixStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) })
Ok(Io::new(io))
}
}
}

View file

@ -37,7 +37,7 @@ use crate::server::{Server, ServerBuilder, StreamServiceFactory};
/// assert!(response.status().is_success());
/// }
/// ```
pub fn test_server<F: StreamServiceFactory<TcpStream>>(factory: F) -> TestServer {
pub fn test_server<F: StreamServiceFactory>(factory: F) -> TestServer {
let (tx, rx) = mpsc::channel();
// run server in separate thread

View file

@ -3,6 +3,7 @@ use std::{cell::Ref, cell::RefCell, cell::RefMut, fmt, net, rc::Rc};
use crate::http::{
HeaderMap, HttpMessage, Message, Method, Payload, RequestHead, Uri, Version,
};
use crate::io::IoRef;
use crate::router::Path;
use crate::util::{Extensions, Ready};
@ -105,6 +106,12 @@ impl HttpRequest {
}
}
/// Io reference for current connection
#[inline]
pub fn io(&self) -> Option<&IoRef> {
self.head().io.as_ref()
}
/// Get a reference to the Path parameters.
///
/// Params is a container for url parameters.
@ -183,17 +190,6 @@ impl HttpRequest {
&self.0.rmap
}
/// Peer socket address
///
/// Peer address is actual socket address, if proxy is used in front of
/// ntex http server, then peer address would be address of this proxy.
///
/// To get client connection information `.connection_info()` should be used.
#[inline]
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
self.head().peer_addr
}
/// Get *ConnectionInfo* for the current request.
///
/// This method panics if request's extensions container is already

View file

@ -119,7 +119,9 @@ impl ConnectionInfo {
}
if remote.is_none() {
// get peeraddr from socketaddr
peer = req.peer_addr.map(|addr| format!("{}", addr));
// TODO! fix
// peer = req.peer_addr.map(|addr| format!("{}", addr));
}
}

View file

@ -6,6 +6,7 @@ use std::{fmt, net};
use crate::http::{
header, HeaderMap, HttpMessage, Method, Payload, RequestHead, Response, Uri, Version,
};
use crate::io::IoRef;
use crate::router::{Path, Resource};
use crate::util::Extensions;
@ -87,6 +88,12 @@ impl<Err> WebRequest<Err> {
WebResponse::new(res.into(), self.req)
}
/// Io reference for current connection
#[inline]
pub fn io(&self) -> Option<&IoRef> {
self.head().io.as_ref()
}
/// This method returns reference to the request head
#[inline]
pub fn head(&self) -> &RequestHead {
@ -147,17 +154,6 @@ impl<Err> WebRequest<Err> {
}
}
/// Peer socket address
///
/// Peer address is actual socket address, if proxy is used in front of
/// ntex http server, then peer address would be address of this proxy.
///
/// To get client connection information `ConnectionInfo` should be used.
#[inline]
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
self.head().peer_addr
}
/// Get *ConnectionInfo* for the current request.
#[inline]
pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> {

View file

@ -2,8 +2,8 @@ use std::{fmt, io, marker::PhantomData, net, sync::Arc, sync::Mutex};
#[cfg(feature = "openssl")]
use crate::server::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder};
#[cfg(feature = "rustls")]
use crate::server::rustls::ServerConfig as RustlsServerConfig;
//#[cfg(feature = "rustls")]
//use crate::server::rustls::ServerConfig as RustlsServerConfig;
#[cfg(unix)]
use crate::http::Protocol;
@ -275,7 +275,6 @@ where
.disconnect_timeout(c.client_disconnect)
.memory_pool(c.pool)
.finish(map_config(factory(), move |_| cfg.clone()))
.tcp()
},
)?;
Ok(self)
@ -326,50 +325,50 @@ where
Ok(self)
}
#[cfg(feature = "rustls")]
/// Use listener for accepting incoming tls connection requests
///
/// This method sets alpn protocols to "h2" and "http/1.1"
pub fn listen_rustls(
self,
lst: net::TcpListener,
config: RustlsServerConfig,
) -> io::Result<Self> {
self.listen_rustls_inner(lst, config)
}
// #[cfg(feature = "rustls")]
// /// Use listener for accepting incoming tls connection requests
// ///
// /// This method sets alpn protocols to "h2" and "http/1.1"
// pub fn listen_rustls(
// self,
// lst: net::TcpListener,
// config: RustlsServerConfig,
// ) -> io::Result<Self> {
// self.listen_rustls_inner(lst, config)
// }
#[cfg(feature = "rustls")]
fn listen_rustls_inner(
mut self,
lst: net::TcpListener,
config: RustlsServerConfig,
) -> io::Result<Self> {
let factory = self.factory.clone();
let cfg = self.config.clone();
let addr = lst.local_addr().unwrap();
// #[cfg(feature = "rustls")]
// fn listen_rustls_inner(
// mut self,
// lst: net::TcpListener,
// config: RustlsServerConfig,
// ) -> io::Result<Self> {
// let factory = self.factory.clone();
// let cfg = self.config.clone();
// let addr = lst.local_addr().unwrap();
self.builder = self.builder.listen(
format!("ntex-web-rustls-service-{}", addr),
lst,
move || {
let c = cfg.lock().unwrap();
let cfg = AppConfig::new(
true,
addr,
c.host.clone().unwrap_or_else(|| format!("{}", addr)),
);
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.disconnect_timeout(c.client_disconnect)
.ssl_handshake_timeout(c.handshake_timeout)
.memory_pool(c.pool)
.finish(map_config(factory(), move |_| cfg.clone()))
.rustls(config.clone())
},
)?;
Ok(self)
}
// self.builder = self.builder.listen(
// format!("ntex-web-rustls-service-{}", addr),
// lst,
// move || {
// let c = cfg.lock().unwrap();
// let cfg = AppConfig::new(
// true,
// addr,
// c.host.clone().unwrap_or_else(|| format!("{}", addr)),
// );
// HttpService::build()
// .keep_alive(c.keep_alive)
// .client_timeout(c.client_timeout)
// .disconnect_timeout(c.client_disconnect)
// .ssl_handshake_timeout(c.handshake_timeout)
// .memory_pool(c.pool)
// .finish(map_config(factory(), move |_| cfg.clone()))
// .rustls(config.clone())
// },
// )?;
// Ok(self)
// }
/// The socket address to bind
///
@ -437,21 +436,21 @@ where
Ok(self)
}
#[cfg(feature = "rustls")]
/// Start listening for incoming tls connections.
///
/// This method sets alpn protocols to "h2" and "http/1.1"
pub fn bind_rustls<A: net::ToSocketAddrs>(
mut self,
addr: A,
config: RustlsServerConfig,
) -> io::Result<Self> {
let sockets = self.bind2(addr)?;
for lst in sockets {
self = self.listen_rustls_inner(lst, config.clone())?;
}
Ok(self)
}
// #[cfg(feature = "rustls")]
// /// Start listening for incoming tls connections.
// ///
// /// This method sets alpn protocols to "h2" and "http/1.1"
// pub fn bind_rustls<A: net::ToSocketAddrs>(
// mut self,
// addr: A,
// config: RustlsServerConfig,
// ) -> io::Result<Self> {
// let sockets = self.bind2(addr)?;
// for lst in sockets {
// self = self.listen_rustls_inner(lst, config.clone())?;
// }
// Ok(self)
// }
#[cfg(unix)]
/// Start listening for unix domain connections on existing listener.
@ -479,16 +478,11 @@ where
socket_addr,
c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)),
);
pipeline_factory(|io: UnixStream| {
crate::util::Ready::Ok((io, Protocol::Http1, None))
})
.and_then(
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.memory_pool(c.pool)
.finish(map_config(factory(), move |_| config.clone())),
)
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.memory_pool(c.pool)
.finish(map_config(factory(), move |_| config.clone()))
})?;
Ok(self)
}
@ -520,16 +514,11 @@ where
socket_addr,
c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)),
);
pipeline_factory(|io: UnixStream| {
crate::util::Ready::Ok((io, Protocol::Http1, None))
})
.and_then(
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.memory_pool(c.pool)
.finish(map_config(factory(), move |_| config.clone())),
)
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.memory_pool(c.pool)
.finish(map_config(factory(), move |_| config.clone()))
},
)?;
Ok(self)

View file

@ -465,15 +465,12 @@ impl TestRequest {
/// Complete request creation and generate `Request` instance
pub fn to_request(mut self) -> Request {
let mut req = self.req.finish();
req.head_mut().peer_addr = self.peer_addr;
req
self.req.finish()
}
/// Complete request creation and generate `WebRequest` instance
pub fn to_srv_request(mut self) -> WebRequest<DefaultError> {
let (mut head, payload) = self.req.finish().into_parts();
head.peer_addr = self.peer_addr;
let (head, payload) = self.req.finish().into_parts();
*self.path.get_mut() = head.uri.clone();
WebRequest::new(HttpRequest::new(
@ -494,8 +491,7 @@ impl TestRequest {
/// Complete request creation and generate `HttpRequest` instance
pub fn to_http_request(mut self) -> HttpRequest {
let (mut head, payload) = self.req.finish().into_parts();
head.peer_addr = self.peer_addr;
let (head, payload) = self.req.finish().into_parts();
*self.path.get_mut() = head.uri.clone();
HttpRequest::new(
@ -511,8 +507,7 @@ impl TestRequest {
/// Complete request creation and generate `HttpRequest` and `Payload` instances
pub fn to_http_parts(mut self) -> (HttpRequest, Payload) {
let (mut head, payload) = self.req.finish().into_parts();
head.peer_addr = self.peer_addr;
let (head, payload) = self.req.finish().into_parts();
*self.path.get_mut() = head.uri.clone();
let req = HttpRequest::new(
@ -636,7 +631,6 @@ where
HttpService::build()
.client_timeout(ctimeout)
.h1(map_config(factory(), move |_| cfg.clone()))
.tcp()
}),
HttpVer::Http2 => builder.listen("test", tcp, move || {
let cfg =
@ -644,7 +638,6 @@ where
HttpService::build()
.client_timeout(ctimeout)
.h2(map_config(factory(), move |_| cfg.clone()))
.tcp()
}),
HttpVer::Both => builder.listen("test", tcp, move || {
let cfg =
@ -652,7 +645,6 @@ where
HttpService::build()
.client_timeout(ctimeout)
.finish(map_config(factory(), move |_| cfg.clone()))
.tcp()
}),
},
#[cfg(feature = "openssl")]
@ -842,8 +834,9 @@ impl TestServerConfig {
/// Start rustls server
#[cfg(feature = "rustls")]
pub fn rustls(mut self, config: rust_tls::ServerConfig) -> Self {
self.stream = StreamType::Rustls(config);
self
// self.stream = StreamType::Rustls(config);
// self
unimplemented!()
}
/// Set server client timeout in seconds for first request.
@ -928,19 +921,12 @@ impl TestServer {
}
/// Connect to websocket server at a given path
pub async fn ws_at(
&self,
path: &str,
) -> Result<ws::WsConnection<impl AsyncRead + AsyncWrite>, WsClientError> {
let url = self.url(path);
let connect = self.client.ws(url).connect();
connect.await
pub async fn ws_at(&self, path: &str) -> Result<ws::WsConnection, WsClientError> {
self.client.ws(self.url(path)).connect().await
}
/// Connect to a websocket server
pub async fn ws(
&self,
) -> Result<ws::WsConnection<impl AsyncRead + AsyncWrite>, WsClientError> {
pub async fn ws(&self) -> Result<ws::WsConnection, WsClientError> {
self.ws_at("/").await
}

View file

@ -25,6 +25,7 @@ pub use self::stream::{StreamDecoder, StreamEncoder};
pub enum WsError<E> {
Service(E),
KeepAlive,
Disconnected,
Protocol(ProtocolError),
Io(io::Error),
}

View file

@ -1,18 +1,18 @@
use std::{future::Future, rc::Rc};
use crate::framed::{OnDisconnect, State};
use crate::io::{Io, IoRef, OnDisconnect};
use crate::ws;
pub struct WsSink(Rc<WsSinkInner>);
struct WsSinkInner {
state: State,
io: IoRef,
codec: ws::Codec,
}
impl WsSink {
pub(crate) fn new(state: State, codec: ws::Codec) -> Self {
Self(Rc::new(WsSinkInner { state, codec }))
pub(crate) fn new(io: IoRef, codec: ws::Codec) -> Self {
Self(Rc::new(WsSinkInner { io, codec }))
}
/// Endcode and send message to the peer.
@ -23,13 +23,13 @@ impl WsSink {
let inner = self.0.clone();
async move {
inner.state.write().encode(item, &inner.codec)?;
inner.io.write().encode(item, &inner.codec)?;
Ok(())
}
}
/// Notify when connection get disconnected
pub fn on_disconnect(&self) -> OnDisconnect {
self.0.state.on_disconnect()
self.0.io.on_disconnect()
}
}

View file

@ -215,15 +215,12 @@ async fn test_connection_reuse() {
num2.fetch_add(1, Ordering::Relaxed);
ok(io)
})
.and_then(
HttpService::new(map_config(
App::new().service(
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
),
|_| AppConfig::default(),
))
.tcp(),
)
.and_then(HttpService::new(map_config(
App::new().service(
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
),
|_| AppConfig::default(),
)))
});
let client = Client::build().timeout(Seconds(10)).finish();
@ -253,15 +250,12 @@ async fn test_connection_force_close() {
num2.fetch_add(1, Ordering::Relaxed);
ok(io)
})
.and_then(
HttpService::new(map_config(
App::new().service(
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
),
|_| AppConfig::default(),
))
.tcp(),
)
.and_then(HttpService::new(map_config(
App::new().service(
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
),
|_| AppConfig::default(),
)))
});
let client = Client::build().timeout(Seconds(10)).finish();
@ -291,15 +285,12 @@ async fn test_connection_server_close() {
num2.fetch_add(1, Ordering::Relaxed);
ok(io)
})
.and_then(
HttpService::new(map_config(
App::new().service(web::resource("/").route(web::to(|| async {
HttpResponse::Ok().force_close().finish()
}))),
|_| AppConfig::default(),
))
.tcp(),
)
.and_then(HttpService::new(map_config(
App::new().service(web::resource("/").route(web::to(|| async {
HttpResponse::Ok().force_close().finish()
}))),
|_| AppConfig::default(),
)))
});
let client = Client::build().timeout(Seconds(10)).finish();
@ -329,16 +320,13 @@ async fn test_connection_wait_queue() {
num2.fetch_add(1, Ordering::Relaxed);
ok(io)
})
.and_then(
HttpService::new(map_config(
App::new().service(
web::resource("/")
.route(web::to(|| async { HttpResponse::Ok().body(STR) })),
),
|_| AppConfig::default(),
))
.tcp(),
)
.and_then(HttpService::new(map_config(
App::new().service(
web::resource("/")
.route(web::to(|| async { HttpResponse::Ok().body(STR) })),
),
|_| AppConfig::default(),
)))
});
let client = Client::build()
@ -378,15 +366,12 @@ async fn test_connection_wait_queue_force_close() {
num2.fetch_add(1, Ordering::Relaxed);
ok(io)
})
.and_then(
HttpService::new(map_config(
App::new().service(web::resource("/").route(web::to(|| async {
HttpResponse::Ok().force_close().body(STR)
}))),
|_| AppConfig::default(),
))
.tcp(),
)
.and_then(HttpService::new(map_config(
App::new().service(web::resource("/").route(web::to(|| async {
HttpResponse::Ok().force_close().body(STR)
}))),
|_| AppConfig::default(),
)))
});
let client = Client::build()

View file

@ -8,8 +8,8 @@ use ntex::http::test::server as test_server;
use ntex::http::{
body, header, HttpService, KeepAlive, Method, Request, Response, StatusCode,
};
use ntex::time::{sleep, Millis};
use ntex::{service::fn_service, time::Seconds, util::Bytes, web::error};
use ntex::time::{sleep, Millis, Seconds};
use ntex::{service::fn_service, util::Bytes, util::Ready, web::error};
#[ntex::test]
async fn test_h1() {
@ -22,7 +22,6 @@ async fn test_h1() {
assert!(req.peer_addr().is_some());
future::ok::<_, io::Error>(Response::Ok().finish())
})
.tcp()
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -41,7 +40,6 @@ async fn test_h1_2() {
assert_eq!(req.version(), http::Version::HTTP_11);
future::ok::<_, io::Error>(Response::Ok().finish())
})
.tcp()
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -72,7 +70,6 @@ async fn test_expect_continue() {
let _ = req.payload().next().await;
Ok::<_, io::Error>(Response::Ok().finish())
}))
.tcp()
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -101,22 +98,18 @@ async fn test_chunked_payload() {
let total_size: usize = chunk_sizes.iter().sum();
let srv = test_server(|| {
HttpService::build()
.h1(fn_service(|mut request: Request| {
request
.take_payload()
.map(|res| match res {
Ok(pl) => pl,
Err(e) => panic!("Error reading payload: {}", e),
})
.fold(0usize, |acc, chunk| ready(acc + chunk.len()))
.map(|req_size| {
Ok::<_, io::Error>(
Response::Ok().body(format!("size={}", req_size)),
)
})
}))
.tcp()
HttpService::build().h1(fn_service(|mut request: Request| {
request
.take_payload()
.map(|res| match res {
Ok(pl) => pl,
Err(e) => panic!("Error reading payload: {}", e),
})
.fold(0usize, |acc, chunk| ready(acc + chunk.len()))
.map(|req_size| {
Ok::<_, io::Error>(Response::Ok().body(format!("size={}", req_size)))
})
}))
});
let returned_size = {
@ -156,7 +149,6 @@ async fn test_slow_request() {
HttpService::build()
.client_timeout(Seconds(1))
.finish(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -169,9 +161,7 @@ async fn test_slow_request() {
#[ntex::test]
async fn test_http1_malformed_request() {
let srv = test_server(|| {
HttpService::build()
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -184,9 +174,7 @@ async fn test_http1_malformed_request() {
#[ntex::test]
async fn test_http1_keepalive() {
let srv = test_server(|| {
HttpService::build()
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -207,7 +195,6 @@ async fn test_http1_keepalive_timeout() {
HttpService::build()
.keep_alive(1)
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -225,9 +212,7 @@ async fn test_http1_keepalive_timeout() {
#[ntex::test]
async fn test_http1_keepalive_close() {
let srv = test_server(|| {
HttpService::build()
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -245,9 +230,7 @@ async fn test_http1_keepalive_close() {
#[ntex::test]
async fn test_http10_keepalive_default_close() {
let srv = test_server(|| {
HttpService::build()
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -264,9 +247,7 @@ async fn test_http10_keepalive_default_close() {
#[ntex::test]
async fn test_http10_keepalive() {
let srv = test_server(|| {
HttpService::build()
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -293,7 +274,6 @@ async fn test_http1_keepalive_disabled() {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
.tcp()
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
@ -315,20 +295,18 @@ async fn test_content_length() {
};
let srv = test_server(|| {
HttpService::build()
.h1(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, io::Error>(Response::new(statuses[indx]))
})
.tcp()
HttpService::build().h1(|req: Request| {
let indx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
StatusCode::SWITCHING_PROTOCOLS,
StatusCode::PROCESSING,
StatusCode::OK,
StatusCode::NOT_FOUND,
];
future::ok::<_, io::Error>(Response::new(statuses[indx]))
})
});
let header = HeaderName::from_static("content-length");
@ -362,7 +340,7 @@ async fn test_h1_headers() {
let data = data.clone();
HttpService::build().h1(move |_| {
let mut builder = Response::Ok();
for idx in 0..90 {
for idx in 0..20 {
builder.header(
format!("X-TEST-{}", idx).as_str(),
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
@ -380,8 +358,9 @@ async fn test_h1_headers() {
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
);
}
future::ok::<_, io::Error>(builder.body(data.clone()))
}).tcp()
println!("SENDING body");
Ready::Ok::<_, io::Error>(builder.body(data.clone()))
})
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -417,9 +396,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
#[ntex::test]
async fn test_h1_body() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
.tcp()
HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -433,9 +410,7 @@ async fn test_h1_body() {
#[ntex::test]
async fn test_h1_head_empty() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
.tcp()
HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
});
let response = srv.request(http::Method::HEAD, "/").send().await.unwrap();
@ -457,13 +432,9 @@ async fn test_h1_head_empty() {
#[ntex::test]
async fn test_h1_head_binary() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| {
ok::<_, io::Error>(
Response::Ok().content_length(STR.len() as u64).body(STR),
)
})
.tcp()
HttpService::build().h1(|_| {
ok::<_, io::Error>(Response::Ok().content_length(STR.len() as u64).body(STR))
})
});
let response = srv.request(http::Method::HEAD, "/").send().await.unwrap();
@ -485,9 +456,7 @@ async fn test_h1_head_binary() {
#[ntex::test]
async fn test_h1_head_binary2() {
let srv = test_server(|| {
HttpService::build()
.h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
.tcp()
HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
});
let response = srv.request(http::Method::HEAD, "/").send().await.unwrap();
@ -505,14 +474,12 @@ async fn test_h1_head_binary2() {
#[ntex::test]
async fn test_h1_body_length() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, io::Error>(
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
)
})
.tcp()
HttpService::build().h1(|_| {
let body = once(ok(Bytes::from_static(STR.as_ref())));
ok::<_, io::Error>(
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
)
})
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -525,18 +492,15 @@ async fn test_h1_body_length() {
#[ntex::test]
async fn test_h1_body_chunked_explicit() {
env_logger::init();
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| {
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
ok::<_, io::Error>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
.tcp()
HttpService::build().h1(|_| {
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
ok::<_, io::Error>(
Response::Ok()
.header(header::TRANSFER_ENCODING, "chunked")
.streaming(body),
)
})
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -561,12 +525,10 @@ async fn test_h1_body_chunked_explicit() {
#[ntex::test]
async fn test_h1_body_chunked_implicit() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| {
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
ok::<_, io::Error>(Response::Ok().streaming(body))
})
.tcp()
HttpService::build().h1(|_| {
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
ok::<_, io::Error>(Response::Ok().streaming(body))
})
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -589,16 +551,14 @@ async fn test_h1_body_chunked_implicit() {
#[ntex::test]
async fn test_h1_response_http_error_handling() {
let mut srv = test_server(|| {
HttpService::build()
.h1(fn_service(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, io::Error>(
Response::Ok()
.header(http::header::CONTENT_TYPE, &broken_header[..])
.body(STR),
)
}))
.tcp()
HttpService::build().h1(fn_service(|_| {
let broken_header = Bytes::from_static(b"\0\0\0");
ok::<_, io::Error>(
Response::Ok()
.header(http::header::CONTENT_TYPE, &broken_header[..])
.body(STR),
)
}))
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -612,14 +572,12 @@ async fn test_h1_response_http_error_handling() {
#[ntex::test]
async fn test_h1_service_error() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|_| {
future::err::<Response, _>(error::InternalError::default(
"error",
StatusCode::BAD_REQUEST,
))
})
.tcp()
HttpService::build().h1(|_| {
future::err::<Response, _>(error::InternalError::default(
"error",
StatusCode::BAD_REQUEST,
))
})
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
@ -629,19 +587,3 @@ async fn test_h1_service_error() {
let bytes = srv.load_body(response).await.unwrap();
assert_eq!(bytes, Bytes::from_static(b"error"));
}
#[ntex::test]
async fn test_h1_on_connect() {
let srv = test_server(|| {
HttpService::build()
.on_connect(|_| 10usize)
.h1(|req: Request| {
assert!(req.extensions().contains::<usize>());
future::ok::<_, io::Error>(Response::Ok().finish())
})
.tcp()
});
let response = srv.request(Method::GET, "/").send().await.unwrap();
assert!(response.status().is_success());
}

View file

@ -3,10 +3,9 @@ use std::sync::{mpsc, Arc};
use std::{io, io::Read, net, thread, time};
use futures::future::{lazy, ok, FutureExt};
use futures::SinkExt;
use ntex::codec::{BytesCodec, Framed};
use ntex::rt::net::TcpStream;
use ntex::codec::BytesCodec;
use ntex::io::Io;
use ntex::server::{Server, TestServer};
use ntex::service::fn_service;
use ntex::util::{Bytes, Ready};
@ -77,9 +76,10 @@ fn test_start() {
.backlog(100)
.disable_signals()
.bind("test", addr, move || {
fn_service(|io: TcpStream| async move {
let mut f = Framed::new(io, BytesCodec);
f.send(Bytes::from_static(b"test")).await.unwrap();
fn_service(|io: Io| async move {
io.send(Bytes::from_static(b"test"), &BytesCodec)
.await
.unwrap();
Ok::<_, ()>(())
})
})