mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-04 13:27:39 +03:00
use ntex-io instead of framed
This commit is contained in:
parent
dafd339817
commit
3dbba47ab1
62 changed files with 1545 additions and 5639 deletions
|
@ -1,8 +1,14 @@
|
|||
# Changes
|
||||
|
||||
## [0.6.0] - 2021-12-xx
|
||||
|
||||
* Removed Framed type
|
||||
|
||||
* Removed tokio dependency
|
||||
|
||||
## [0.5.1] - 2021-09-08
|
||||
|
||||
* Fix tight loop in Framed::close() method.
|
||||
* Fix tight loop in Framed::close() method
|
||||
|
||||
## [0.5.0] - 2021-06-27
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex-codec"
|
||||
version = "0.5.1"
|
||||
version = "0.6.0"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "Utilities for encoding and decoding frames"
|
||||
keywords = ["network", "framework", "async", "futures"]
|
||||
|
@ -16,12 +16,4 @@ name = "ntex_codec"
|
|||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
bitflags = "1.3"
|
||||
ntex-bytes = "0.1"
|
||||
ntex-util = "0.1"
|
||||
log = "0.4"
|
||||
tokio = { version = "1", default-features = false }
|
||||
|
||||
[dev-dependencies]
|
||||
ntex = "0.4.13"
|
||||
futures = "0.3.13"
|
||||
|
|
|
@ -1,691 +0,0 @@
|
|||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{fmt, io};
|
||||
|
||||
use ntex_bytes::{Buf, BytesMut};
|
||||
use ntex_util::{future::Either, ready, Sink, Stream};
|
||||
|
||||
use crate::{AsyncRead, AsyncWrite, Decoder, Encoder};
|
||||
|
||||
const LW: usize = 1024;
|
||||
const HW: usize = 8 * 1024;
|
||||
|
||||
bitflags::bitflags! {
|
||||
struct Flags: u8 {
|
||||
const EOF = 0b0001;
|
||||
const READABLE = 0b0010;
|
||||
const DISCONNECTED = 0b0100;
|
||||
const SHUTDOWN = 0b1000;
|
||||
}
|
||||
}
|
||||
|
||||
/// A unified interface to an underlying I/O object, using
|
||||
/// the `Encoder` and `Decoder` traits to encode and decode frames.
|
||||
/// `Framed` is heavily optimized for streaming io.
|
||||
pub struct Framed<T, U> {
|
||||
io: T,
|
||||
codec: U,
|
||||
flags: Flags,
|
||||
read_buf: BytesMut,
|
||||
write_buf: BytesMut,
|
||||
err: Option<io::Error>,
|
||||
}
|
||||
|
||||
impl<T, U> Framed<T, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite,
|
||||
U: Decoder + Encoder,
|
||||
{
|
||||
#[inline]
|
||||
/// Provides an interface for reading and writing to
|
||||
/// `Io` object, using `Decode` and `Encode` traits of codec.
|
||||
///
|
||||
/// Raw I/O objects work with byte sequences, but higher-level code usually
|
||||
/// wants to batch these into meaningful chunks, called "frames". This
|
||||
/// method layers framing on top of an I/O object, by using the `Codec`
|
||||
/// traits to handle encoding and decoding of messages frames. Note that
|
||||
/// the incoming and outgoing frame types may be distinct.
|
||||
pub fn new(io: T, codec: U) -> Framed<T, U> {
|
||||
Framed {
|
||||
io,
|
||||
codec,
|
||||
err: None,
|
||||
flags: Flags::empty(),
|
||||
read_buf: BytesMut::with_capacity(HW),
|
||||
write_buf: BytesMut::with_capacity(HW),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Framed<T, U> {
|
||||
#[inline]
|
||||
/// Construct `Framed` object `parts`.
|
||||
pub fn from_parts(parts: FramedParts<T, U>) -> Framed<T, U> {
|
||||
Framed {
|
||||
io: parts.io,
|
||||
codec: parts.codec,
|
||||
flags: parts.flags,
|
||||
write_buf: parts.write_buf,
|
||||
read_buf: parts.read_buf,
|
||||
err: parts.err,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Returns a reference to the underlying codec.
|
||||
pub fn get_codec(&self) -> &U {
|
||||
&self.codec
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Returns a mutable reference to the underlying codec.
|
||||
pub fn get_codec_mut(&mut self) -> &mut U {
|
||||
&mut self.codec
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Returns a reference to the underlying I/O stream wrapped by `Framed`.
|
||||
///
|
||||
/// Note that care should be taken to not tamper with the underlying stream
|
||||
/// of data coming in as it may corrupt the stream of frames otherwise
|
||||
/// being worked with.
|
||||
pub fn get_ref(&self) -> &T {
|
||||
&self.io
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Returns a mutable reference to the underlying I/O stream wrapped by
|
||||
/// `Framed`.
|
||||
///
|
||||
/// Note that care should be taken to not tamper with the underlying stream
|
||||
/// of data coming in as it may corrupt the stream of frames otherwise
|
||||
/// being worked with.
|
||||
pub fn get_mut(&mut self) -> &mut T {
|
||||
&mut self.io
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Get read buffer.
|
||||
pub fn read_buf(&mut self) -> &mut BytesMut {
|
||||
&mut self.read_buf
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Get write buffer.
|
||||
pub fn write_buf(&mut self) -> &mut BytesMut {
|
||||
&mut self.write_buf
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if write buffer is empty.
|
||||
pub fn is_write_buf_empty(&self) -> bool {
|
||||
self.write_buf.is_empty()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if write buffer is full.
|
||||
pub fn is_write_buf_full(&self) -> bool {
|
||||
self.write_buf.len() >= HW
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if framed object is closed
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.flags.contains(Flags::DISCONNECTED)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Consume the `Frame`, returning `Frame` with different codec.
|
||||
pub fn into_framed<U2>(self, codec: U2) -> Framed<T, U2> {
|
||||
Framed {
|
||||
codec,
|
||||
io: self.io,
|
||||
flags: self.flags,
|
||||
read_buf: self.read_buf,
|
||||
write_buf: self.write_buf,
|
||||
err: self.err,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Consume the `Frame`, returning `Frame` with different io.
|
||||
pub fn map_io<F, T2>(self, f: F) -> Framed<T2, U>
|
||||
where
|
||||
F: Fn(T) -> T2,
|
||||
{
|
||||
Framed {
|
||||
io: f(self.io),
|
||||
codec: self.codec,
|
||||
flags: self.flags,
|
||||
read_buf: self.read_buf,
|
||||
write_buf: self.write_buf,
|
||||
err: self.err,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Consume the `Frame`, returning `Frame` with different codec.
|
||||
pub fn map_codec<F, U2>(self, f: F) -> Framed<T, U2>
|
||||
where
|
||||
F: Fn(U) -> U2,
|
||||
{
|
||||
Framed {
|
||||
io: self.io,
|
||||
codec: f(self.codec),
|
||||
flags: self.flags,
|
||||
read_buf: self.read_buf,
|
||||
write_buf: self.write_buf,
|
||||
err: self.err,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Consumes the `Frame`, returning its underlying I/O stream, the buffer
|
||||
/// with unprocessed data, and the codec.
|
||||
///
|
||||
/// Note that care should be taken to not tamper with the underlying stream
|
||||
/// of data coming in as it may corrupt the stream of frames otherwise
|
||||
/// being worked with.
|
||||
pub fn into_parts(self) -> FramedParts<T, U> {
|
||||
FramedParts {
|
||||
io: self.io,
|
||||
codec: self.codec,
|
||||
flags: self.flags,
|
||||
read_buf: self.read_buf,
|
||||
write_buf: self.write_buf,
|
||||
err: self.err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Framed<T, U>
|
||||
where
|
||||
T: AsyncWrite + Unpin,
|
||||
U: Encoder,
|
||||
{
|
||||
#[inline]
|
||||
/// Serialize item and Write to the inner buffer
|
||||
pub fn write(
|
||||
&mut self,
|
||||
item: <U as Encoder>::Item,
|
||||
) -> Result<(), <U as Encoder>::Error> {
|
||||
let remaining = self.write_buf.capacity() - self.write_buf.len();
|
||||
if remaining < LW {
|
||||
self.write_buf.reserve(HW - remaining);
|
||||
}
|
||||
|
||||
self.codec.encode(item, &mut self.write_buf)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if framed is able to write more data.
|
||||
///
|
||||
/// `Framed` object considers ready if there is free space in write buffer.
|
||||
pub fn is_write_ready(&self) -> bool {
|
||||
self.write_buf.len() < HW
|
||||
}
|
||||
|
||||
/// Flush write buffer to underlying I/O stream.
|
||||
pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
log::trace!("flushing framed transport");
|
||||
|
||||
let len = self.write_buf.len();
|
||||
if len != 0 {
|
||||
let mut written = 0;
|
||||
while written < len {
|
||||
match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) {
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!(
|
||||
"Disconnected during flush, written {}",
|
||||
written
|
||||
);
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"failed to write frame to transport",
|
||||
)));
|
||||
} else {
|
||||
written += n
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
log::trace!("Error during flush: {}", e);
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
log::trace!("flushed {} bytes", written);
|
||||
|
||||
// remove written data
|
||||
if written == len {
|
||||
self.write_buf.clear()
|
||||
} else {
|
||||
self.write_buf.advance(written);
|
||||
}
|
||||
}
|
||||
|
||||
// flush
|
||||
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
|
||||
|
||||
if self.write_buf.is_empty() {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Framed<T, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
#[inline]
|
||||
/// Flush write buffer and shutdown underlying I/O stream.
|
||||
///
|
||||
/// Close method shutdown write side of a io object and
|
||||
/// then reads until disconnect or error, high level code must use
|
||||
/// timeout for close operation.
|
||||
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
if !self.flags.contains(Flags::DISCONNECTED) {
|
||||
// flush write buffer
|
||||
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
|
||||
|
||||
if !self.flags.contains(Flags::SHUTDOWN) {
|
||||
// shutdown WRITE side
|
||||
ready!(Pin::new(&mut self.io).poll_shutdown(cx)).map_err(|e| {
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
e
|
||||
})?;
|
||||
self.flags.insert(Flags::SHUTDOWN);
|
||||
}
|
||||
|
||||
// read until 0 or err
|
||||
let mut buf = [0u8; 512];
|
||||
loop {
|
||||
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
|
||||
match ready!(Pin::new(&mut self.io).poll_read(cx, &mut read_buf)) {
|
||||
Err(_) | Ok(_) if read_buf.filled().is_empty() => {
|
||||
break;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
self.flags.insert(Flags::DISCONNECTED);
|
||||
}
|
||||
log::trace!("framed transport flushed and closed");
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
pub type ItemType<U> =
|
||||
Result<<U as Decoder>::Item, Either<<U as Decoder>::Error, io::Error>>;
|
||||
|
||||
impl<T, U> Framed<T, U>
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
U: Decoder,
|
||||
{
|
||||
/// Try to read underlying I/O stream and decode item.
|
||||
pub fn next_item(&mut self, cx: &mut Context<'_>) -> Poll<Option<ItemType<U>>> {
|
||||
let mut done_read = false;
|
||||
|
||||
loop {
|
||||
// Repeatedly call `decode` or `decode_eof` as long as it is
|
||||
// "readable". Readable is defined as not having returned `None`. If
|
||||
// the upstream has returned EOF, and the decoder is no longer
|
||||
// readable, it can be assumed that the decoder will never become
|
||||
// readable again, at which point the stream is terminated.
|
||||
|
||||
if self.flags.contains(Flags::READABLE) {
|
||||
if self.flags.contains(Flags::EOF) {
|
||||
return match self.codec.decode_eof(&mut self.read_buf) {
|
||||
Ok(Some(frame)) => Poll::Ready(Some(Ok(frame))),
|
||||
Ok(None) => {
|
||||
if let Some(err) = self.err.take() {
|
||||
Poll::Ready(Some(Err(Either::Right(err))))
|
||||
} else if !self.read_buf.is_empty() {
|
||||
Poll::Ready(Some(Err(Either::Right(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"bytes remaining on stream",
|
||||
)))))
|
||||
} else {
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))),
|
||||
};
|
||||
}
|
||||
|
||||
log::trace!("attempting to decode a frame");
|
||||
|
||||
match self.codec.decode(&mut self.read_buf) {
|
||||
Ok(Some(frame)) => {
|
||||
log::trace!("frame decoded from buffer");
|
||||
return Poll::Ready(Some(Ok(frame)));
|
||||
}
|
||||
Err(e) => return Poll::Ready(Some(Err(Either::Left(e)))),
|
||||
_ => (), // Need more data
|
||||
}
|
||||
|
||||
self.flags.remove(Flags::READABLE);
|
||||
if done_read {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
debug_assert!(!self.flags.contains(Flags::EOF));
|
||||
|
||||
// read all data from socket
|
||||
let mut updated = false;
|
||||
loop {
|
||||
// Otherwise, try to read more data and try again. Make sure we've got room
|
||||
let remaining = self.read_buf.capacity() - self.read_buf.len();
|
||||
if remaining < LW {
|
||||
self.read_buf.reserve(HW - remaining)
|
||||
}
|
||||
match crate::poll_read_buf(
|
||||
Pin::new(&mut self.io),
|
||||
cx,
|
||||
&mut self.read_buf,
|
||||
) {
|
||||
Poll::Pending => {
|
||||
if updated {
|
||||
done_read = true;
|
||||
self.flags.insert(Flags::READABLE);
|
||||
break;
|
||||
} else {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
self.flags.insert(Flags::EOF | Flags::READABLE);
|
||||
if updated {
|
||||
done_read = true;
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
updated = true;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
if updated {
|
||||
done_read = true;
|
||||
self.err = Some(e);
|
||||
self.flags.insert(Flags::EOF | Flags::READABLE);
|
||||
break;
|
||||
} else {
|
||||
return Poll::Ready(Some(Err(Either::Right(e))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Stream for Framed<T, U>
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
U: Decoder + Unpin,
|
||||
{
|
||||
type Item = Result<U::Item, Either<U::Error, io::Error>>;
|
||||
|
||||
#[inline]
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
self.next_item(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Sink<U::Item> for Framed<T, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
U: Encoder + Unpin,
|
||||
{
|
||||
type Error = Either<U::Error, io::Error>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(
|
||||
self: Pin<&mut Self>,
|
||||
_: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
if self.is_write_ready() {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn start_send(
|
||||
mut self: Pin<&mut Self>,
|
||||
item: <U as Encoder>::Item,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.write(item).map_err(Either::Left)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.flush(cx).map_err(Either::Right)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_close(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), Self::Error>> {
|
||||
self.close(cx).map_err(Either::Right)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> fmt::Debug for Framed<T, U>
|
||||
where
|
||||
T: fmt::Debug,
|
||||
U: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Framed")
|
||||
.field("io", &self.io)
|
||||
.field("codec", &self.codec)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// `FramedParts` contains an export of the data of a Framed transport.
|
||||
/// It can be used to construct a new `Framed` with a different codec.
|
||||
/// It contains all current buffers and the inner transport.
|
||||
#[derive(Debug)]
|
||||
pub struct FramedParts<T, U> {
|
||||
/// The inner transport used to read bytes to and write bytes to
|
||||
pub io: T,
|
||||
|
||||
/// The codec
|
||||
pub codec: U,
|
||||
|
||||
/// The buffer with read but unprocessed data.
|
||||
pub read_buf: BytesMut,
|
||||
|
||||
/// A buffer with unprocessed data which are not written yet.
|
||||
pub write_buf: BytesMut,
|
||||
|
||||
flags: Flags,
|
||||
err: Option<io::Error>,
|
||||
}
|
||||
|
||||
impl<T, U> FramedParts<T, U> {
|
||||
/// Create a new, default, `FramedParts`
|
||||
pub fn new(io: T, codec: U) -> FramedParts<T, U> {
|
||||
FramedParts {
|
||||
io,
|
||||
codec,
|
||||
err: None,
|
||||
flags: Flags::empty(),
|
||||
read_buf: BytesMut::new(),
|
||||
write_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new `FramedParts` with read buffer
|
||||
pub fn with_read_buf(io: T, codec: U, read_buf: BytesMut) -> FramedParts<T, U> {
|
||||
FramedParts {
|
||||
io,
|
||||
codec,
|
||||
read_buf,
|
||||
err: None,
|
||||
flags: Flags::empty(),
|
||||
write_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::{future::lazy, Sink};
|
||||
use ntex::testing::Io;
|
||||
use ntex_bytes::Bytes;
|
||||
|
||||
use super::*;
|
||||
use crate::BytesCodec;
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_basics() {
|
||||
let (_, server) = Io::create();
|
||||
let mut server = Framed::new(server, BytesCodec);
|
||||
server.get_codec_mut();
|
||||
server.get_ref();
|
||||
server.get_mut();
|
||||
|
||||
let parts = server.into_parts();
|
||||
let server = Framed::from_parts(FramedParts::new(parts.io, parts.codec));
|
||||
assert!(format!("{:?}", server).contains("Framed"));
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_sink() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
let mut server = Framed::new(server, BytesCodec);
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
|
||||
.await
|
||||
.is_ready());
|
||||
|
||||
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
|
||||
Pin::new(&mut server).start_send(data).unwrap();
|
||||
assert_eq!(client.read_any(), b"".as_ref());
|
||||
assert_eq!(server.read_buf(), b"".as_ref());
|
||||
assert_eq!(server.write_buf(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
|
||||
.await
|
||||
.is_ready());
|
||||
assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
|
||||
.await
|
||||
.is_pending());
|
||||
client.close().await;
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
|
||||
.await
|
||||
.is_ready());
|
||||
assert!(client.is_closed());
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_write_pending() {
|
||||
let (client, server) = Io::create();
|
||||
let mut server = Framed::new(server, BytesCodec);
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
|
||||
.await
|
||||
.is_ready());
|
||||
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
|
||||
Pin::new(&mut server).start_send(data).unwrap();
|
||||
|
||||
client.remote_buffer_cap(3);
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
|
||||
.await
|
||||
.is_pending());
|
||||
assert_eq!(client.read_any(), b"GET".as_ref());
|
||||
|
||||
client.remote_buffer_cap(1024);
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
|
||||
.await
|
||||
.is_ready());
|
||||
assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref());
|
||||
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
|
||||
.await
|
||||
.is_pending());
|
||||
client.close().await;
|
||||
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
|
||||
.await
|
||||
.is_ready());
|
||||
assert!(client.is_closed());
|
||||
assert!(server.is_closed());
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_read_pending() {
|
||||
let (client, server) = Io::create();
|
||||
let mut server = Framed::new(server, BytesCodec);
|
||||
|
||||
client.read_pending();
|
||||
assert!(lazy(|cx| Pin::new(&mut server).next_item(cx))
|
||||
.await
|
||||
.is_pending());
|
||||
|
||||
client.write(b"GET /test HTTP/1.1\r\n\r\n");
|
||||
client.close().await;
|
||||
|
||||
let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
|
||||
.await
|
||||
.map(|i| i.unwrap().unwrap().freeze());
|
||||
assert_eq!(
|
||||
item,
|
||||
Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"))
|
||||
);
|
||||
let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
|
||||
.await
|
||||
.map(|i| i.is_none());
|
||||
assert_eq!(item, Poll::Ready(true));
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_read_error() {
|
||||
let (client, server) = Io::create();
|
||||
let mut server = Framed::new(server, BytesCodec);
|
||||
|
||||
client.read_pending();
|
||||
assert!(lazy(|cx| Pin::new(&mut server).next_item(cx))
|
||||
.await
|
||||
.is_pending());
|
||||
|
||||
client.write(b"GET /test HTTP/1.1\r\n\r\n");
|
||||
client.read_error(io::Error::new(io::ErrorKind::Other, "error"));
|
||||
|
||||
let item = lazy(|cx| Pin::new(&mut server).next_item(cx))
|
||||
.await
|
||||
.map(|i| i.unwrap().unwrap().freeze());
|
||||
assert_eq!(
|
||||
item,
|
||||
Poll::Ready(Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"))
|
||||
);
|
||||
assert_eq!(
|
||||
lazy(|cx| Pin::new(&mut server).next_item(cx))
|
||||
.await
|
||||
.map(|i| i.unwrap().is_err()),
|
||||
Poll::Ready(true)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -1,55 +1,10 @@
|
|||
//! Utilities for encoding and decoding frames.
|
||||
//!
|
||||
//! Contains adapters to go from streams of bytes, [`AsyncRead`] and
|
||||
//! [`AsyncWrite`], to framed streams implementing `Sink` and `Stream`.
|
||||
//! Framed streams are also known as `transports`.
|
||||
//!
|
||||
//! [`AsyncRead`]: #
|
||||
//! [`AsyncWrite`]: #
|
||||
#![deny(rust_2018_idioms, warnings)]
|
||||
use std::{io, mem::MaybeUninit, pin::Pin, task::Context, task::Poll};
|
||||
//! Utilities for encoding and decoding frames.
|
||||
|
||||
mod bcodec;
|
||||
mod decoder;
|
||||
mod encoder;
|
||||
mod framed;
|
||||
|
||||
pub use self::bcodec::BytesCodec;
|
||||
pub use self::decoder::Decoder;
|
||||
pub use self::encoder::Encoder;
|
||||
pub use self::framed::{Framed, FramedParts};
|
||||
|
||||
pub use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use ntex_bytes::{BufMut, BytesMut};
|
||||
|
||||
pub fn poll_read_buf<T: AsyncRead>(
|
||||
io: Pin<&mut T>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut BytesMut,
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if !buf.has_remaining_mut() {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
|
||||
let n = {
|
||||
let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [MaybeUninit<u8>]) };
|
||||
let mut buf = ReadBuf::uninit(dst);
|
||||
let ptr = buf.filled().as_ptr();
|
||||
if io.poll_read(cx, &mut buf)?.is_pending() {
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
// Ensure the pointer does not change from under us
|
||||
assert_eq!(ptr, buf.filled().as_ptr());
|
||||
buf.filled().len()
|
||||
};
|
||||
|
||||
// Safety: This is guaranteed to be the number of initialized (and read)
|
||||
// bytes due to the invariants provided by `ReadBuf::filled`.
|
||||
unsafe {
|
||||
buf.advance_mut(n);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
//! Framed transport dispatcher
|
||||
use std::{
|
||||
cell::Cell, future::Future, io, pin::Pin, rc::Rc, task::Context, task::Poll, time,
|
||||
cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time,
|
||||
};
|
||||
|
||||
use ntex_bytes::Pool;
|
||||
|
@ -70,7 +70,6 @@ enum DispatcherError<S, U> {
|
|||
KeepAlive,
|
||||
Encoder(U),
|
||||
Service(S),
|
||||
Io(io::Error),
|
||||
}
|
||||
|
||||
enum PollService<U: Encoder + Decoder> {
|
||||
|
@ -157,7 +156,7 @@ where
|
|||
///
|
||||
/// By default disconnect timeout is set to 1 seconds.
|
||||
pub fn disconnect_timeout(self, val: Seconds) -> Self {
|
||||
self.inner.state.set_disconnect_timeout(val);
|
||||
self.inner.state.set_disconnect_timeout(val.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
@ -176,12 +175,7 @@ where
|
|||
Ok(Some(val)) => match write.encode(val, &self.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(Either::Left(err)) => {
|
||||
self.error.set(Some(DispatcherError::Encoder(err)))
|
||||
}
|
||||
Err(Either::Right(err)) => {
|
||||
self.error.set(Some(DispatcherError::Io(err)))
|
||||
}
|
||||
Err(err) => self.error.set(Some(DispatcherError::Encoder(err))),
|
||||
},
|
||||
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
|
||||
Ok(None) => return,
|
||||
|
@ -407,12 +401,7 @@ where
|
|||
Ok(Some(item)) => match write.encode(item, &self.shared.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(Either::Left(err)) => {
|
||||
self.shared.error.set(Some(DispatcherError::Encoder(err)))
|
||||
}
|
||||
Err(Either::Right(err)) => {
|
||||
self.shared.error.set(Some(DispatcherError::Io(err)))
|
||||
}
|
||||
Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))),
|
||||
},
|
||||
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
|
||||
Ok(None) => (),
|
||||
|
@ -443,9 +432,6 @@ where
|
|||
DispatcherError::Encoder(err) => {
|
||||
PollService::Item(DispatchItem::EncoderError(err))
|
||||
}
|
||||
DispatcherError::Io(err) => {
|
||||
PollService::Item(DispatchItem::Disconnect(Some(err)))
|
||||
}
|
||||
DispatcherError::Service(err) => {
|
||||
self.error.set(Some(err));
|
||||
PollService::ServiceError
|
||||
|
|
|
@ -46,13 +46,7 @@ impl ReadFilter for DefaultFilter {
|
|||
|
||||
#[inline]
|
||||
fn read_closed(&self, err: Option<io::Error>) {
|
||||
if err.is_some() {
|
||||
self.0.error.set(err);
|
||||
}
|
||||
self.0.write_task.wake();
|
||||
self.0.dispatch_task.wake();
|
||||
self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
|
||||
self.0.notify_disconnect();
|
||||
self.0.set_error(err);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -109,13 +103,9 @@ impl WriteFilter for DefaultFilter {
|
|||
|
||||
#[inline]
|
||||
fn write_closed(&self, err: Option<io::Error>) {
|
||||
if err.is_some() {
|
||||
self.0.error.set(err);
|
||||
}
|
||||
self.0.read_task.wake();
|
||||
self.0.set_error(err);
|
||||
self.0.insert_flags(Flags::IO_CLOSED);
|
||||
self.0.dispatch_task.wake();
|
||||
self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
|
||||
self.0.notify_disconnect();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::{fmt, future::Future, io, task::Context, task::Poll};
|
||||
use std::{any::Any, any::TypeId, fmt, future::Future, io, task::Context, task::Poll};
|
||||
|
||||
pub mod testing;
|
||||
|
||||
|
@ -14,12 +14,12 @@ mod tokio_impl;
|
|||
|
||||
use ntex_bytes::BytesMut;
|
||||
use ntex_codec::{Decoder, Encoder};
|
||||
use ntex_util::time::Millis;
|
||||
use ntex_util::{channel::oneshot::Receiver, future::Either, time::Millis};
|
||||
|
||||
pub use self::dispatcher::Dispatcher;
|
||||
pub use self::filter::DefaultFilter;
|
||||
pub use self::state::{Io, IoRef, ReadRef, WriteRef};
|
||||
pub use self::tasks::{ReadState, WriteState};
|
||||
pub use self::state::{Io, IoRef, OnDisconnect, ReadRef, WriteRef};
|
||||
pub use self::tasks::{ReadContext, WriteContext};
|
||||
pub use self::time::Timer;
|
||||
|
||||
pub use self::utils::{filter_factory, from_iostream, into_boxed, into_io};
|
||||
|
@ -55,8 +55,15 @@ pub trait WriteFilter {
|
|||
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
|
||||
}
|
||||
|
||||
pub trait Filter: ReadFilter + WriteFilter {
|
||||
pub trait Filter: ReadFilter + WriteFilter + 'static {
|
||||
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
|
||||
|
||||
fn query(
|
||||
&self,
|
||||
id: TypeId,
|
||||
) -> Either<Option<Box<dyn Any>>, Receiver<Option<Box<dyn Any>>>> {
|
||||
Either::Left(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FilterFactory<F: Filter>: Sized {
|
||||
|
@ -69,7 +76,7 @@ pub trait FilterFactory<F: Filter>: Sized {
|
|||
}
|
||||
|
||||
pub trait IoStream {
|
||||
fn start(self, _: ReadState, _: WriteState);
|
||||
fn start(self, _: ReadContext, _: WriteContext);
|
||||
}
|
||||
|
||||
/// Framed transport item
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use std::cell::{Cell, RefCell};
|
||||
use std::task::{Context, Poll};
|
||||
use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
|
||||
use std::{fmt, future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
|
||||
|
||||
use ntex_bytes::{BytesMut, PoolId, PoolRef};
|
||||
use ntex_codec::{Decoder, Encoder};
|
||||
use ntex_util::time::{Millis, Seconds};
|
||||
use ntex_util::time::Millis;
|
||||
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
|
||||
|
||||
use super::filter::{DefaultFilter, NullFilter};
|
||||
use super::tasks::{ReadState, WriteState};
|
||||
use super::tasks::{ReadContext, WriteContext};
|
||||
use super::{Filter, FilterFactory, IoStream};
|
||||
|
||||
bitflags::bitflags! {
|
||||
|
@ -21,13 +21,15 @@ bitflags::bitflags! {
|
|||
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
|
||||
/// shutdown io tasks
|
||||
const IO_SHUTDOWN = 0b0000_0000_0000_1000;
|
||||
/// io object is closed
|
||||
const IO_CLOSED = 0b0000_0000_0001_0000;
|
||||
|
||||
/// pause io read
|
||||
const RD_PAUSED = 0b0000_0000_0000_1000;
|
||||
const RD_PAUSED = 0b0000_0000_0010_0000;
|
||||
/// new data is available
|
||||
const RD_READY = 0b0000_0000_0001_0000;
|
||||
const RD_READY = 0b0000_0000_0100_0000;
|
||||
/// read buffer is full
|
||||
const RD_BUF_FULL = 0b0000_0000_0010_0000;
|
||||
const RD_BUF_FULL = 0b0000_0000_1000_0000;
|
||||
|
||||
/// wait write completion
|
||||
const WR_WAIT = 0b0000_0001_0000_0000;
|
||||
|
@ -103,8 +105,22 @@ impl IoStateInner {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn is_io_err(&self) -> bool {
|
||||
self.flags.get().contains(Flags::IO_ERR)
|
||||
fn is_io_open(&self) -> bool {
|
||||
!self.flags.get().intersects(
|
||||
Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_SHUTDOWN | Flags::IO_CLOSED,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn set_error(&self, err: Option<io::Error>) {
|
||||
if err.is_some() {
|
||||
self.error.set(err);
|
||||
}
|
||||
self.read_task.wake();
|
||||
self.write_task.wake();
|
||||
self.dispatch_task.wake();
|
||||
self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
|
||||
self.notify_disconnect();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -195,7 +211,7 @@ impl Io {
|
|||
let io_ref = IoRef(inner);
|
||||
|
||||
// start io tasks
|
||||
io.start(ReadState(io_ref.clone()), WriteState(io_ref.clone()));
|
||||
io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone()));
|
||||
|
||||
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
|
||||
}
|
||||
|
@ -218,8 +234,8 @@ impl<F> Io<F> {
|
|||
|
||||
#[inline]
|
||||
/// Set io disconnect timeout in secs
|
||||
pub fn set_disconnect_timeout(&self, timeout: Seconds) {
|
||||
self.0 .0.disconnect_timeout.set(timeout.into());
|
||||
pub fn set_disconnect_timeout(&self, timeout: Millis) {
|
||||
self.0 .0.disconnect_timeout.set(timeout);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -242,31 +258,6 @@ impl<F> Io<F> {
|
|||
pub fn register_dispatcher(&self, cx: &mut Context<'_>) {
|
||||
self.0 .0.dispatch_task.register(cx.waker());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Mark dispatcher as stopped
|
||||
pub fn dispatcher_stopped(&self) {
|
||||
self.0 .0.insert_flags(Flags::DSP_STOP);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Gracefully shutdown read and write io tasks
|
||||
pub fn init_shutdown(&self, cx: &mut Context<'_>) {
|
||||
let flags = self.0 .0.flags.get();
|
||||
|
||||
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
|
||||
log::trace!("initiate io shutdown {:?}", flags);
|
||||
self.0 .0.insert_flags(Flags::IO_FILTERS);
|
||||
if let Err(err) = self.0 .0.shutdown_filters(&self.0) {
|
||||
self.0 .0.error.set(Some(err));
|
||||
self.0 .0.insert_flags(Flags::IO_ERR);
|
||||
}
|
||||
|
||||
self.0 .0.read_task.wake();
|
||||
self.0 .0.write_task.wake();
|
||||
self.0 .0.dispatch_task.register(cx.waker());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IoRef {
|
||||
|
@ -284,9 +275,9 @@ impl IoRef {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
/// Check if io error occured in read or write task
|
||||
pub fn is_io_err(&self) -> bool {
|
||||
self.0.is_io_err()
|
||||
/// Check if io is still active
|
||||
pub fn is_io_open(&self) -> bool {
|
||||
self.0.is_io_open()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -304,10 +295,13 @@ impl IoRef {
|
|||
#[inline]
|
||||
/// Check if io stream is closed
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.0
|
||||
.flags
|
||||
.get()
|
||||
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP)
|
||||
self.0.flags.get().intersects(
|
||||
Flags::IO_ERR
|
||||
| Flags::IO_SHUTDOWN
|
||||
| Flags::IO_CLOSED
|
||||
| Flags::IO_FILTERS
|
||||
| Flags::DSP_STOP,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -316,6 +310,12 @@ impl IoRef {
|
|||
self.0.error.take()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Mark dispatcher as stopped
|
||||
pub fn stop_dispatcher(&self) {
|
||||
self.0.insert_flags(Flags::DSP_STOP);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Reset keep-alive error
|
||||
pub fn reset_keepalive(&self) {
|
||||
|
@ -360,9 +360,15 @@ impl IoRef {
|
|||
pub fn on_disconnect(&self) -> OnDisconnect {
|
||||
OnDisconnect::new(self.0.clone(), self.0.flags.get().contains(Flags::IO_ERR))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Query specific data
|
||||
pub fn query<T: 'static>(&self) -> Option<T> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Io<F> {
|
||||
impl IoRef {
|
||||
#[inline]
|
||||
/// Read incoming io stream and decode codec item.
|
||||
pub async fn next<U>(
|
||||
|
@ -375,18 +381,18 @@ impl<F> Io<F> {
|
|||
let read = self.read();
|
||||
|
||||
loop {
|
||||
let mut buf = self.0 .0.read_buf.take();
|
||||
let mut buf = self.0.read_buf.take();
|
||||
let item = if let Some(ref mut buf) = buf {
|
||||
codec.decode(buf)
|
||||
} else {
|
||||
Ok(None)
|
||||
};
|
||||
self.0 .0.read_buf.set(buf);
|
||||
self.0.read_buf.set(buf);
|
||||
|
||||
return match item {
|
||||
Ok(Some(el)) => Ok(Some(el)),
|
||||
Ok(None) => {
|
||||
self.0 .0.remove_flags(Flags::RD_READY);
|
||||
self.0.remove_flags(Flags::RD_READY);
|
||||
if poll_fn(|cx| read.poll_ready(cx))
|
||||
.await
|
||||
.map_err(Either::Right)?
|
||||
|
@ -411,53 +417,53 @@ impl<F> Io<F> {
|
|||
where
|
||||
U: Encoder,
|
||||
{
|
||||
let filter = self.0 .0.filter.get();
|
||||
let filter = self.0.filter.get();
|
||||
let mut buf = filter
|
||||
.get_write_buf()
|
||||
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
|
||||
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
|
||||
|
||||
let is_write_sleep = buf.is_empty();
|
||||
codec.encode(item, &mut buf).map_err(Either::Left)?;
|
||||
filter.release_write_buf(buf).map_err(Either::Right)?;
|
||||
self.0 .0.insert_flags(Flags::WR_WAIT);
|
||||
self.0.insert_flags(Flags::WR_WAIT);
|
||||
if is_write_sleep {
|
||||
self.0 .0.write_task.wake();
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
|
||||
poll_fn(|cx| self.write().poll_flush(cx))
|
||||
poll_fn(|cx| self.write().poll_flush(cx, true))
|
||||
.await
|
||||
.map_err(Either::Right)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Shuts down connection
|
||||
pub async fn shutdown(&self) -> Result<(), io::Error> {
|
||||
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
Ok(())
|
||||
} else {
|
||||
poll_fn(|cx| {
|
||||
let flags = self.flags();
|
||||
if !flags.contains(Flags::IO_FILTERS) {
|
||||
self.init_shutdown(cx);
|
||||
}
|
||||
/// Shut down connection
|
||||
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
let flags = self.flags();
|
||||
|
||||
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
if let Some(err) = self.0 .0.error.take() {
|
||||
Poll::Ready(Err(err))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
} else {
|
||||
self.0 .0.insert_flags(Flags::IO_FILTERS);
|
||||
self.0 .0.dispatch_task.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
})
|
||||
.await
|
||||
if flags.intersects(Flags::IO_ERR | Flags::IO_CLOSED) {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
if !flags.contains(Flags::IO_FILTERS) {
|
||||
self.init_shutdown(cx);
|
||||
}
|
||||
self.0.insert_flags(Flags::IO_FILTERS);
|
||||
|
||||
if let Some(err) = self.0.error.take() {
|
||||
Poll::Ready(Err(err))
|
||||
} else {
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Shut down connection
|
||||
pub async fn shutdown(&self) -> Result<(), io::Error> {
|
||||
poll_fn(|cx| self.poll_shutdown(cx)).await
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn poll_next<U>(
|
||||
|
@ -468,38 +474,48 @@ impl<F> Io<F> {
|
|||
where
|
||||
U: Decoder,
|
||||
{
|
||||
if self
|
||||
.read()
|
||||
.poll_ready(cx)
|
||||
.map_err(Either::Right)?
|
||||
.is_ready()
|
||||
{
|
||||
let mut buf = self.0 .0.read_buf.take();
|
||||
let item = if let Some(ref mut buf) = buf {
|
||||
codec.decode(buf)
|
||||
} else {
|
||||
Ok(None)
|
||||
};
|
||||
self.0 .0.read_buf.set(buf);
|
||||
let read = self.read();
|
||||
|
||||
match item {
|
||||
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
|
||||
Ok(None) => {
|
||||
if let Poll::Ready(res) =
|
||||
self.read().poll_ready(cx).map_err(Either::Right)?
|
||||
{
|
||||
if res.is_none() {
|
||||
return Poll::Ready(Ok(None));
|
||||
}
|
||||
match read.decode(codec) {
|
||||
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
|
||||
Ok(None) => {
|
||||
if let Poll::Ready(res) = read.poll_ready(cx).map_err(Either::Right)? {
|
||||
if res.is_none() {
|
||||
return Poll::Ready(Ok(None));
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(Either::Left(err))),
|
||||
Poll::Pending
|
||||
}
|
||||
} else {
|
||||
Poll::Pending
|
||||
Err(err) => Poll::Ready(Err(Either::Left(err))),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Gracefully shutdown read and write io tasks
|
||||
pub(super) fn init_shutdown(&self, cx: &mut Context<'_>) {
|
||||
let flags = self.0.flags.get();
|
||||
|
||||
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
|
||||
log::trace!("initiate io shutdown {:?}", flags);
|
||||
self.0.insert_flags(Flags::IO_FILTERS);
|
||||
if let Err(err) = self.0.shutdown_filters(self) {
|
||||
self.0.error.set(Some(err));
|
||||
self.0.insert_flags(Flags::IO_ERR);
|
||||
}
|
||||
|
||||
self.0.read_task.wake();
|
||||
self.0.write_task.wake();
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for IoRef {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("IoRef")
|
||||
.field("open", &!self.is_closed())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> Io<F> {
|
||||
|
@ -576,7 +592,10 @@ impl<F: Filter> Io<F> {
|
|||
|
||||
impl<F> Drop for Io<F> {
|
||||
fn drop(&mut self) {
|
||||
log::trace!("stopping io stream");
|
||||
log::trace!(
|
||||
"io is dropped, force stopping io streams {:?}",
|
||||
self.0.flags()
|
||||
);
|
||||
if let FilterItem::Ptr(p) = self.1 {
|
||||
if p.is_null() {
|
||||
return;
|
||||
|
@ -635,7 +654,7 @@ impl<'a> WriteRef<'a> {
|
|||
///
|
||||
/// Write task must be waken up separately.
|
||||
pub fn enable_backpressure(&self, cx: Option<&mut Context<'_>>) {
|
||||
log::trace!("enable write back-pressure");
|
||||
log::trace!("enable write back-pressure {:?}", cx.is_some());
|
||||
self.0.insert_flags(Flags::WR_BACKPRESSURE);
|
||||
if let Some(cx) = cx {
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
|
@ -669,7 +688,7 @@ impl<'a> WriteRef<'a> {
|
|||
&self,
|
||||
item: U::Item,
|
||||
codec: &U,
|
||||
) -> Result<bool, Either<<U as Encoder>::Error, io::Error>>
|
||||
) -> Result<bool, <U as Encoder>::Error>
|
||||
where
|
||||
U: Encoder,
|
||||
{
|
||||
|
@ -690,28 +709,44 @@ impl<'a> WriteRef<'a> {
|
|||
}
|
||||
|
||||
// encode item and wake write task
|
||||
let result = codec
|
||||
.encode(item, &mut buf)
|
||||
.map(|_| {
|
||||
if is_write_sleep {
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
buf.len() < hw
|
||||
})
|
||||
.map_err(Either::Left);
|
||||
filter.release_write_buf(buf).map_err(Either::Right)?;
|
||||
Ok(result?)
|
||||
let result = codec.encode(item, &mut buf).map(|_| {
|
||||
if is_write_sleep {
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
buf.len() < hw
|
||||
});
|
||||
if let Err(err) = filter.release_write_buf(buf) {
|
||||
self.0.set_error(Some(err));
|
||||
}
|
||||
result
|
||||
} else {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Wake write task and instruct to write all data.
|
||||
/// Wake write task and instruct to write data.
|
||||
///
|
||||
/// When write task is done wake dispatcher.
|
||||
pub fn poll_flush(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.0.insert_flags(Flags::WR_WAIT);
|
||||
/// If full is true then wake up dispatcher when all data is flushed
|
||||
/// otherwise wake up when size of write buffer is lower than
|
||||
/// buffer max size.
|
||||
pub fn poll_flush(
|
||||
&self,
|
||||
cx: &mut Context<'_>,
|
||||
full: bool,
|
||||
) -> Poll<Result<(), io::Error>> {
|
||||
// check io error
|
||||
if !self.0.is_io_open() {
|
||||
return Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, "disconnected")
|
||||
})));
|
||||
}
|
||||
|
||||
if full {
|
||||
self.0.insert_flags(Flags::WR_WAIT);
|
||||
} else {
|
||||
self.0.insert_flags(Flags::WR_BACKPRESSURE);
|
||||
}
|
||||
|
||||
if let Some(buf) = self.0.write_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
|
@ -722,14 +757,16 @@ impl<'a> WriteRef<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
if self.0.is_io_err() {
|
||||
Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, "disconnected")
|
||||
})))
|
||||
} else {
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
// self.0.dispatch_task.register(cx.waker());
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Wake write task and instruct to write data.
|
||||
///
|
||||
/// This is async version of .poll_flush() method.
|
||||
pub async fn flush(&self, full: bool) -> Result<(), io::Error> {
|
||||
poll_fn(|cx| self.poll_flush(cx, full)).await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -834,7 +871,7 @@ impl<'a> ReadRef<'a> {
|
|||
let mut flags = self.0.flags.get();
|
||||
let ready = flags.contains(Flags::RD_READY);
|
||||
|
||||
if self.0.is_io_err() {
|
||||
if !self.0.is_io_open() {
|
||||
if let Some(err) = self.0.error.take() {
|
||||
Poll::Ready(Err(err))
|
||||
} else {
|
||||
|
@ -843,7 +880,6 @@ impl<'a> ReadRef<'a> {
|
|||
} else if ready {
|
||||
Poll::Ready(Ok(Some(())))
|
||||
} else {
|
||||
flags.remove(Flags::RD_READY);
|
||||
if flags.contains(Flags::RD_BUF_FULL) {
|
||||
log::trace!("read back-pressure is enabled, wake io task");
|
||||
flags.remove(Flags::RD_BUF_FULL);
|
||||
|
@ -939,7 +975,6 @@ mod tests {
|
|||
|
||||
#[ntex::test]
|
||||
async fn utils() {
|
||||
env_logger::init();
|
||||
let (client, server) = IoTest::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write(TEXT);
|
||||
|
@ -1041,7 +1076,7 @@ mod tests {
|
|||
in_bytes: Rc<Cell<usize>>,
|
||||
out_bytes: Rc<Cell<usize>>,
|
||||
}
|
||||
impl<F: ReadFilter + WriteFilter> Filter for Counter<F> {
|
||||
impl<F: ReadFilter + WriteFilter + 'static> Filter for Counter<F> {
|
||||
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
|
|
@ -4,9 +4,9 @@ use ntex_bytes::{BytesMut, PoolRef};
|
|||
|
||||
use super::{state::Flags, IoRef, WriteReadiness};
|
||||
|
||||
pub struct ReadState(pub(super) IoRef);
|
||||
pub struct ReadContext(pub(super) IoRef);
|
||||
|
||||
impl ReadState {
|
||||
impl ReadContext {
|
||||
#[inline]
|
||||
pub fn memory_pool(&self) -> PoolRef {
|
||||
self.0 .0.pool.get()
|
||||
|
@ -60,9 +60,9 @@ impl ReadState {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct WriteState(pub(super) IoRef);
|
||||
pub struct WriteContext(pub(super) IoRef);
|
||||
|
||||
impl WriteState {
|
||||
impl WriteContext {
|
||||
#[inline]
|
||||
pub fn memory_pool(&self) -> PoolRef {
|
||||
self.0 .0.pool.get()
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
use std::cell::{Cell, RefCell};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use std::{cmp, fmt, io, mem};
|
||||
use std::{cmp, fmt, future::Future, io, mem, pin::Pin, rc::Rc};
|
||||
|
||||
use ntex_bytes::{BufMut, BytesMut};
|
||||
use ntex_bytes::{Buf, BufMut, BytesMut};
|
||||
use ntex_util::future::poll_fn;
|
||||
use ntex_util::time::{sleep, Millis};
|
||||
use ntex_util::time::{sleep, Millis, Sleep};
|
||||
|
||||
use crate::{IoStream, ReadContext, WriteContext, WriteReadiness};
|
||||
|
||||
#[derive(Default)]
|
||||
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);
|
||||
|
@ -441,138 +443,181 @@ mod tokio {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tokio"))]
|
||||
mod non_tokio {
|
||||
impl IoStream for IoTest {
|
||||
fn start(self, read: ReadState, write: WriteState) {
|
||||
let io = Rc::new(self);
|
||||
impl IoStream for IoTest {
|
||||
fn start(self, read: ReadContext, write: WriteContext) {
|
||||
let io = Rc::new(self);
|
||||
|
||||
ntex_util::spawn(ReadTask {
|
||||
io: io.clone(),
|
||||
state: read,
|
||||
});
|
||||
ntex_util::spawn(WriteTask {
|
||||
io,
|
||||
state: write,
|
||||
st: IoWriteState::Processing,
|
||||
});
|
||||
}
|
||||
ntex_util::spawn(ReadTask {
|
||||
io: io.clone(),
|
||||
state: read,
|
||||
});
|
||||
ntex_util::spawn(WriteTask {
|
||||
io,
|
||||
state: write,
|
||||
st: IoWriteState::Processing(None),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Read io task
|
||||
struct ReadTask {
|
||||
io: Rc<IoTest>,
|
||||
state: ReadState,
|
||||
}
|
||||
/// Read io task
|
||||
struct ReadTask {
|
||||
io: Rc<IoTest>,
|
||||
state: ReadContext,
|
||||
}
|
||||
|
||||
impl Future for ReadTask {
|
||||
type Output = ();
|
||||
impl Future for ReadTask {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.as_ref();
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.as_ref();
|
||||
|
||||
match this.state.poll_ready(cx) {
|
||||
Poll::Ready(Err(())) => {
|
||||
log::trace!("read task is instructed to terminate");
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Ready(Ok(())) => {
|
||||
let io = &this.io;
|
||||
let pool = this.state.memory_pool();
|
||||
let mut buf = self.state.get_read_buf();
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
match this.state.poll_ready(cx) {
|
||||
Poll::Ready(Err(())) => {
|
||||
log::trace!("read task is instructed to terminate");
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Ready(Ok(())) => {
|
||||
let io = &this.io;
|
||||
let pool = this.state.memory_pool();
|
||||
let mut buf = self.state.get_read_buf();
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
|
||||
// read data from socket
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
match io.poll_read_buf(cx, &mut buf) {
|
||||
Poll::Pending => {
|
||||
log::trace!("no more data in io stream");
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("io stream is disconnected");
|
||||
this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(None);
|
||||
return Poll::Ready(());
|
||||
} else {
|
||||
new_bytes += n;
|
||||
if buf.len() > hw {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("read task failed on io {:?}", err);
|
||||
this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(Some(err));
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
// read data from socket
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
this.state.release_read_buf(buf, new_bytes);
|
||||
Poll::Pending
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum IoWriteState {
|
||||
Processing,
|
||||
Shutdown(Option<Sleep>, Shutdown),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Shutdown {
|
||||
None,
|
||||
Flushed,
|
||||
Stopping,
|
||||
}
|
||||
|
||||
/// Write io task
|
||||
struct WriteTask {
|
||||
st: IoWriteState,
|
||||
io: Rc<IoTest>,
|
||||
state: WriteState,
|
||||
}
|
||||
|
||||
impl Future for WriteTask {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.as_mut().get_mut();
|
||||
|
||||
match this.st {
|
||||
IoWriteState::Processing => {
|
||||
match this.state.poll_ready(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
// flush framed instance
|
||||
match flush_io(&this.io, &this.state, cx) {
|
||||
Poll::Pending | Poll::Ready(true) => Poll::Pending,
|
||||
Poll::Ready(false) => Poll::Ready(()),
|
||||
match io.poll_read_buf(cx, &mut buf) {
|
||||
Poll::Pending => {
|
||||
log::trace!("no more data in io stream");
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("io stream is disconnected");
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(None);
|
||||
return Poll::Ready(());
|
||||
} else {
|
||||
new_bytes += n;
|
||||
if buf.len() > hw {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Shutdown)) => {
|
||||
log::trace!("write task is instructed to shutdown");
|
||||
|
||||
this.st = IoWriteState::Shutdown(
|
||||
this.state.disconnect_timeout().map(sleep),
|
||||
Shutdown::None,
|
||||
);
|
||||
self.poll(cx)
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("read task failed on io {:?}", err);
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(Some(err));
|
||||
return Poll::Ready(());
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Terminate)) => {
|
||||
log::trace!("write task is instructed to terminate");
|
||||
}
|
||||
}
|
||||
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
Poll::Pending
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum IoWriteState {
|
||||
Processing(Option<Sleep>),
|
||||
Shutdown(Option<Sleep>, Shutdown),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Shutdown {
|
||||
None,
|
||||
Flushed,
|
||||
Stopping,
|
||||
}
|
||||
|
||||
/// Write io task
|
||||
struct WriteTask {
|
||||
st: IoWriteState,
|
||||
io: Rc<IoTest>,
|
||||
state: WriteContext,
|
||||
}
|
||||
|
||||
impl Future for WriteTask {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.as_mut().get_mut();
|
||||
|
||||
match this.st {
|
||||
IoWriteState::Processing(ref mut delay) => {
|
||||
match this.state.poll_ready(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
// flush framed instance
|
||||
match flush_io(&this.io, &this.state, cx) {
|
||||
Poll::Pending | Poll::Ready(true) => Poll::Pending,
|
||||
Poll::Ready(false) => Poll::Ready(()),
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
|
||||
if delay.is_none() {
|
||||
*delay = Some(sleep(time));
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
|
||||
log::trace!("write task is instructed to shutdown");
|
||||
|
||||
let timeout = if let Some(delay) = delay.take() {
|
||||
delay
|
||||
} else {
|
||||
sleep(time)
|
||||
};
|
||||
|
||||
this.st = IoWriteState::Shutdown(Some(timeout), Shutdown::None);
|
||||
self.poll(cx)
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Terminate)) => {
|
||||
log::trace!("write task is instructed to terminate");
|
||||
// shutdown WRITE side
|
||||
this.io
|
||||
.local
|
||||
.lock()
|
||||
.unwrap()
|
||||
.borrow_mut()
|
||||
.flags
|
||||
.insert(Flags::CLOSED);
|
||||
this.state.close(None);
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
|
||||
// close WRITE side and wait for disconnect on read side.
|
||||
// use disconnect timeout, otherwise it could hang forever.
|
||||
loop {
|
||||
match st {
|
||||
Shutdown::None => {
|
||||
// flush write buffer
|
||||
match flush_io(&this.io, &this.state, cx) {
|
||||
Poll::Ready(true) => {
|
||||
*st = Shutdown::Flushed;
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(false) => {
|
||||
log::trace!(
|
||||
"write task is closed with err during flush"
|
||||
);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
Shutdown::Flushed => {
|
||||
// shutdown WRITE side
|
||||
this.io
|
||||
.local
|
||||
|
@ -581,143 +626,102 @@ mod non_tokio {
|
|||
.borrow_mut()
|
||||
.flags
|
||||
.insert(Flags::CLOSED);
|
||||
this.state.close(None);
|
||||
Poll::Ready(())
|
||||
*st = Shutdown::Stopping;
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
|
||||
// close WRITE side and wait for disconnect on read side.
|
||||
// use disconnect timeout, otherwise it could hang forever.
|
||||
loop {
|
||||
match st {
|
||||
Shutdown::None => {
|
||||
// flush write buffer
|
||||
match flush_io(&this.io, &this.state, cx) {
|
||||
Poll::Ready(true) => {
|
||||
*st = Shutdown::Flushed;
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(false) => {
|
||||
log::trace!(
|
||||
"write task is closed with err during flush"
|
||||
);
|
||||
Shutdown::Stopping => {
|
||||
// read until 0 or err
|
||||
let io = &this.io;
|
||||
loop {
|
||||
let mut buf = BytesMut::new();
|
||||
match io.poll_read_buf(cx, &mut buf) {
|
||||
Poll::Ready(Err(e)) => {
|
||||
this.state.close(Some(e));
|
||||
log::trace!("write task is stopped");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
Poll::Ready(Ok(n)) if n == 0 => {
|
||||
this.state.close(None);
|
||||
log::trace!("write task is stopped");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
Poll::Pending => break,
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
Shutdown::Flushed => {
|
||||
// shutdown WRITE side
|
||||
this.io
|
||||
.local
|
||||
.lock()
|
||||
.unwrap()
|
||||
.borrow_mut()
|
||||
.flags
|
||||
.insert(Flags::CLOSED);
|
||||
*st = Shutdown::Stopping;
|
||||
continue;
|
||||
}
|
||||
Shutdown::Stopping => {
|
||||
// read until 0 or err
|
||||
let io = &this.io;
|
||||
loop {
|
||||
let mut buf = BytesMut::new();
|
||||
match io.poll_read_buf(cx, &mut buf) {
|
||||
Poll::Ready(Err(e)) => {
|
||||
this.state.close(Some(e));
|
||||
log::trace!("write task is stopped");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
Poll::Ready(Ok(n)) if n == 0 => {
|
||||
this.state.close(None);
|
||||
log::trace!("write task is stopped");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
Poll::Pending => break,
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// disconnect timeout
|
||||
if let Some(ref delay) = delay {
|
||||
if delay.poll_elapsed(cx).is_pending() {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
log::trace!("write task is stopped after delay");
|
||||
this.state.close(None);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
// disconnect timeout
|
||||
if let Some(ref delay) = delay {
|
||||
if delay.poll_elapsed(cx).is_pending() {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
log::trace!("write task is stopped after delay");
|
||||
this.state.close(None);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush write buffer to underlying I/O stream.
|
||||
pub(super) fn flush_io(
|
||||
io: &IoTest,
|
||||
state: &WriteState,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<bool> {
|
||||
let mut buf = if let Some(buf) = state.get_write_buf() {
|
||||
buf
|
||||
} else {
|
||||
return Poll::Ready(true);
|
||||
};
|
||||
let len = buf.len();
|
||||
let pool = state.memory_pool();
|
||||
/// Flush write buffer to underlying I/O stream.
|
||||
pub(super) fn flush_io(
|
||||
io: &IoTest,
|
||||
state: &WriteContext,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<bool> {
|
||||
let mut buf = if let Some(buf) = state.get_write_buf() {
|
||||
buf
|
||||
} else {
|
||||
return Poll::Ready(true);
|
||||
};
|
||||
let len = buf.len();
|
||||
|
||||
if len != 0 {
|
||||
log::trace!("flushing framed transport: {}", len);
|
||||
if len != 0 {
|
||||
log::trace!("flushing framed transport: {}", len);
|
||||
|
||||
let mut written = 0;
|
||||
while written < len {
|
||||
match io.poll_write_buf(cx, &buf[written..]) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!(
|
||||
"disconnected during flush, written {}",
|
||||
written
|
||||
);
|
||||
pool.release_write_buf(buf);
|
||||
state.close(Some(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"failed to write frame to transport",
|
||||
)));
|
||||
return Poll::Ready(false);
|
||||
} else {
|
||||
written += n
|
||||
}
|
||||
}
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Err(e)) => {
|
||||
log::trace!("error during flush: {}", e);
|
||||
pool.release_write_buf(buf);
|
||||
state.close(Some(e));
|
||||
let mut written = 0;
|
||||
while written < len {
|
||||
match io.poll_write_buf(cx, &buf[written..]) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("disconnected during flush, written {}", written);
|
||||
let _ = state.release_write_buf(buf);
|
||||
state.close(Some(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"failed to write frame to transport",
|
||||
)));
|
||||
return Poll::Ready(false);
|
||||
} else {
|
||||
written += n
|
||||
}
|
||||
}
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Err(e)) => {
|
||||
log::trace!("error during flush: {}", e);
|
||||
let _ = state.release_write_buf(buf);
|
||||
state.close(Some(e));
|
||||
return Poll::Ready(false);
|
||||
}
|
||||
}
|
||||
log::trace!("flushed {} bytes", written);
|
||||
|
||||
// remove written data
|
||||
if written == len {
|
||||
buf.clear();
|
||||
state.release_write_buf(buf);
|
||||
Poll::Ready(true)
|
||||
} else {
|
||||
buf.advance(written);
|
||||
state.release_write_buf(buf);
|
||||
Poll::Pending
|
||||
}
|
||||
} else {
|
||||
Poll::Ready(true)
|
||||
}
|
||||
log::trace!("flushed {} bytes", written);
|
||||
|
||||
// remove written data
|
||||
if written == len {
|
||||
buf.clear();
|
||||
let _ = state.release_write_buf(buf);
|
||||
Poll::Ready(true)
|
||||
} else {
|
||||
buf.advance(written);
|
||||
let _ = state.release_write_buf(buf);
|
||||
Poll::Pending
|
||||
}
|
||||
} else {
|
||||
Poll::Ready(true)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,15 +3,12 @@ use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc};
|
|||
|
||||
use ntex_bytes::{Buf, BufMut};
|
||||
use ntex_util::time::{sleep, Sleep};
|
||||
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf};
|
||||
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf, net::TcpStream};
|
||||
|
||||
use super::{IoStream, ReadState, WriteReadiness, WriteState};
|
||||
use super::{IoStream, ReadContext, WriteContext, WriteReadiness};
|
||||
|
||||
impl<T> IoStream for T
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
fn start(self, read: ReadState, write: WriteState) {
|
||||
impl IoStream for TcpStream {
|
||||
fn start(self, read: ReadContext, write: WriteContext) {
|
||||
let io = Rc::new(RefCell::new(self));
|
||||
|
||||
ntex_util::spawn(ReadTask::new(io.clone(), read));
|
||||
|
@ -19,26 +16,29 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Read io task
|
||||
struct ReadTask<T> {
|
||||
io: Rc<RefCell<T>>,
|
||||
state: ReadState,
|
||||
#[cfg(unix)]
|
||||
impl IoStream for tok_io::net::UnixStream {
|
||||
fn start(self, _read: ReadContext, _write: WriteContext) {
|
||||
let _io = Rc::new(RefCell::new(self));
|
||||
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ReadTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
/// Read io task
|
||||
struct ReadTask {
|
||||
io: Rc<RefCell<TcpStream>>,
|
||||
state: ReadContext,
|
||||
}
|
||||
|
||||
impl ReadTask {
|
||||
/// Create new read io task
|
||||
fn new(io: Rc<RefCell<T>>, state: ReadState) -> Self {
|
||||
fn new(io: Rc<RefCell<TcpStream>>, state: ReadContext) -> Self {
|
||||
Self { io, state }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for ReadTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
impl Future for ReadTask {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
|
@ -119,18 +119,15 @@ enum Shutdown {
|
|||
}
|
||||
|
||||
/// Write io task
|
||||
struct WriteTask<T> {
|
||||
struct WriteTask {
|
||||
st: IoWriteState,
|
||||
io: Rc<RefCell<T>>,
|
||||
state: WriteState,
|
||||
io: Rc<RefCell<TcpStream>>,
|
||||
state: WriteContext,
|
||||
}
|
||||
|
||||
impl<T> WriteTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
impl WriteTask {
|
||||
/// Create new write io task
|
||||
fn new(io: Rc<RefCell<T>>, state: WriteState) -> Self {
|
||||
fn new(io: Rc<RefCell<TcpStream>>, state: WriteContext) -> Self {
|
||||
Self {
|
||||
io,
|
||||
state,
|
||||
|
@ -139,10 +136,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> Future for WriteTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
impl Future for WriteTask {
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
|
@ -272,7 +266,7 @@ where
|
|||
/// Flush write buffer to underlying I/O stream.
|
||||
pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
|
||||
io: &mut T,
|
||||
state: &WriteState,
|
||||
state: &WriteContext,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<bool> {
|
||||
let mut buf = if let Some(buf) = state.get_write_buf() {
|
||||
|
@ -284,12 +278,14 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
|
|||
let pool = state.memory_pool();
|
||||
|
||||
if len != 0 {
|
||||
// log::trace!("flushing framed transport: {:?}", buf);
|
||||
//log::trace!("flushing framed transport: {:?}", buf);
|
||||
|
||||
let mut written = 0;
|
||||
while written < len {
|
||||
match Pin::new(&mut *io).poll_write(cx, &buf[written..]) {
|
||||
Poll::Pending => break,
|
||||
Poll::Pending => {
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("Disconnected during flush, written {}", written);
|
||||
|
@ -311,7 +307,7 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
|
|||
}
|
||||
}
|
||||
}
|
||||
// log::trace!("flushed {} bytes", written);
|
||||
//log::trace!("flushed {} bytes", written);
|
||||
|
||||
// remove written data
|
||||
let result = if written == len {
|
||||
|
|
|
@ -22,6 +22,6 @@ ntex-util = "0.1.2"
|
|||
openssl = "0.10.32"
|
||||
|
||||
[dev-dependencies]
|
||||
ntex = { version = "0.4.14", features = ["openssl"] }
|
||||
ntex = { version = "0.5.0", features = ["openssl"] }
|
||||
futures = "0.3"
|
||||
env_logger = "0.9"
|
||||
|
|
|
@ -10,6 +10,13 @@ use ntex_io::{
|
|||
use ntex_util::{future::poll_fn, time, time::Millis};
|
||||
use openssl::ssl::{self, SslStream};
|
||||
|
||||
/// Selected alpn protocol
|
||||
pub enum AlpnHttpProtocol {
|
||||
Http1,
|
||||
Http2,
|
||||
}
|
||||
|
||||
/// An implementation of SSL streams
|
||||
pub struct SslFilter<F> {
|
||||
inner: RefCell<SslStream<IoInner<F>>>,
|
||||
}
|
||||
|
@ -191,7 +198,7 @@ impl SslAcceptor {
|
|||
/// Set handshake timeout.
|
||||
///
|
||||
/// Default is set to 5 seconds.
|
||||
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
|
||||
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
|
||||
self.timeout = timeout.into();
|
||||
self
|
||||
}
|
||||
|
@ -209,7 +216,7 @@ impl Clone for SslAcceptor {
|
|||
impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
|
||||
type Filter = SslFilter<F>;
|
||||
|
||||
type Error = io::Error;
|
||||
type Error = Box<dyn Error>;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, Self::Error>>>>;
|
||||
|
||||
fn create(self, st: Io<F>) -> Self::Future {
|
||||
|
@ -225,8 +232,7 @@ impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
|
|||
read_buf: None,
|
||||
write_buf: None,
|
||||
};
|
||||
let ssl_stream =
|
||||
ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?;
|
||||
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
||||
|
||||
Ok(SslFilter {
|
||||
inner: RefCell::new(ssl_stream),
|
||||
|
@ -234,9 +240,9 @@ impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
|
|||
})?;
|
||||
|
||||
poll_fn(|cx| {
|
||||
let _ = st.write().poll_flush(cx)?;
|
||||
let _ = st.write().poll_flush(cx, true)?;
|
||||
handle_result(st.filter().inner.borrow_mut().accept(), &st, cx)
|
||||
.map_err(map_to_ioerr)
|
||||
.map_err(Into::<Box<dyn Error>>::into)
|
||||
})
|
||||
.await?;
|
||||
|
||||
|
@ -244,7 +250,7 @@ impl<F: Filter + 'static> FilterFactory<F> for SslAcceptor {
|
|||
})
|
||||
.await
|
||||
.map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout")
|
||||
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
|
||||
})
|
||||
.and_then(|item| item)
|
||||
})
|
||||
|
@ -265,7 +271,7 @@ impl SslConnector {
|
|||
impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
|
||||
type Filter = SslFilter<F>;
|
||||
|
||||
type Error = io::Error;
|
||||
type Error = Box<dyn Error>;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, Self::Error>>>>;
|
||||
|
||||
fn create(self, st: Io<F>) -> Self::Future {
|
||||
|
@ -277,8 +283,7 @@ impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
|
|||
read_buf: None,
|
||||
write_buf: None,
|
||||
};
|
||||
let ssl_stream =
|
||||
ssl::SslStream::new(ssl, inner).map_err(map_to_ioerr)?;
|
||||
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
||||
|
||||
Ok(SslFilter {
|
||||
inner: RefCell::new(ssl_stream),
|
||||
|
@ -286,9 +291,9 @@ impl<F: Filter + 'static> FilterFactory<F> for SslConnector {
|
|||
})?;
|
||||
|
||||
poll_fn(|cx| {
|
||||
let _ = st.write().poll_flush(cx)?;
|
||||
let _ = st.write().poll_flush(cx, true)?;
|
||||
handle_result(st.filter().inner.borrow_mut().connect(), &st, cx)
|
||||
.map_err(map_to_ioerr)
|
||||
.map_err(Into::<Box<dyn Error>>::into)
|
||||
})
|
||||
.await?;
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex"
|
||||
version = "0.4.14"
|
||||
version = "0.5.0"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "Framework for composable network services"
|
||||
readme = "README.md"
|
||||
|
|
|
@ -1,148 +0,0 @@
|
|||
use std::task::{Context, Poll};
|
||||
use std::{future::Future, pin::Pin};
|
||||
|
||||
use crate::io::Io;
|
||||
use crate::service::{Service, ServiceFactory};
|
||||
use crate::util::{PoolId, PoolRef, Ready};
|
||||
|
||||
use super::service::ConnectServiceResponse;
|
||||
use super::{Address, Connect, ConnectError, Connector};
|
||||
|
||||
pub struct IoConnector<T> {
|
||||
inner: Connector<T>,
|
||||
pool: PoolRef,
|
||||
}
|
||||
|
||||
impl<T> IoConnector<T> {
|
||||
/// Construct new connect service with custom dns resolver
|
||||
pub fn new() -> Self {
|
||||
IoConnector {
|
||||
inner: Connector::new(),
|
||||
pool: PoolId::P0.pool_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set memory pool.
|
||||
///
|
||||
/// Use specified memory pool for memory allocations. By default P0
|
||||
/// memory pool is used.
|
||||
pub fn memory_pool(mut self, id: PoolId) -> Self {
|
||||
self.pool = id.pool_ref();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> IoConnector<T> {
|
||||
/// Resolve and connect to remote host
|
||||
pub fn connect<U>(&self, message: U) -> IoConnectServiceResponse<T>
|
||||
where
|
||||
Connect<T>: From<U>,
|
||||
{
|
||||
IoConnectServiceResponse {
|
||||
inner: self.inner.call(message.into()),
|
||||
pool: self.pool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for IoConnector<T> {
|
||||
fn default() -> Self {
|
||||
IoConnector::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for IoConnector<T> {
|
||||
fn clone(&self) -> Self {
|
||||
IoConnector {
|
||||
inner: self.inner.clone(),
|
||||
pool: self.pool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> ServiceFactory for IoConnector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = Io;
|
||||
type Error = ConnectError;
|
||||
type Config = ();
|
||||
type Service = IoConnector<T>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Self::Service, Self::InitError>;
|
||||
|
||||
#[inline]
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
Ready::Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> Service for IoConnector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = Io;
|
||||
type Error = ConnectError;
|
||||
type Future = IoConnectServiceResponse<T>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn call(&self, req: Connect<T>) -> Self::Future {
|
||||
self.connect(req)
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct IoConnectServiceResponse<T: Address> {
|
||||
inner: ConnectServiceResponse<T>,
|
||||
pool: PoolRef,
|
||||
}
|
||||
|
||||
impl<T: Address> Future for IoConnectServiceResponse<T> {
|
||||
type Output = Result<Io, ConnectError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match Pin::new(&mut self.inner).poll(cx) {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(Ok(stream)) => {
|
||||
Poll::Ready(Ok(Io::with_memory_pool(stream, self.pool)))
|
||||
}
|
||||
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_connect() {
|
||||
let server = crate::server::test_server(|| {
|
||||
crate::service::fn_service(|_| async { Ok::<_, ()>(()) })
|
||||
});
|
||||
|
||||
let srv = IoConnector::default();
|
||||
let result = srv.connect("").await;
|
||||
assert!(result.is_err());
|
||||
let result = srv.connect("localhost:99999").await;
|
||||
assert!(result.is_err());
|
||||
|
||||
let srv = IoConnector::default();
|
||||
let result = srv.connect(format!("{}", server.addr())).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let msg = Connect::new(format!("{}", server.addr())).set_addrs(vec![
|
||||
format!("127.0.0.1:{}", server.addr().port() - 1)
|
||||
.parse()
|
||||
.unwrap(),
|
||||
server.addr(),
|
||||
]);
|
||||
let result = crate::connect::connect(msg).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let msg = Connect::new(server.addr());
|
||||
let result = crate::connect::connect(msg).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
|
@ -2,7 +2,6 @@
|
|||
use std::future::Future;
|
||||
|
||||
mod error;
|
||||
mod io;
|
||||
mod message;
|
||||
mod resolve;
|
||||
mod service;
|
||||
|
@ -13,19 +12,18 @@ mod uri;
|
|||
#[cfg(feature = "openssl")]
|
||||
pub mod openssl;
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
pub mod rustls;
|
||||
|
||||
use crate::rt::net::TcpStream;
|
||||
//#[cfg(feature = "rustls")]
|
||||
//pub mod rustls;
|
||||
|
||||
pub use self::error::ConnectError;
|
||||
pub use self::io::IoConnector;
|
||||
pub use self::message::{Address, Connect};
|
||||
pub use self::resolve::Resolver;
|
||||
pub use self::service::Connector;
|
||||
|
||||
use crate::io::Io;
|
||||
|
||||
/// Resolve and connect to remote host
|
||||
pub fn connect<T, U>(message: U) -> impl Future<Output = Result<TcpStream, ConnectError>>
|
||||
pub fn connect<T, U>(message: U) -> impl Future<Output = Result<Io, ConnectError>>
|
||||
where
|
||||
T: Address + 'static,
|
||||
Connect<T>: From<U>,
|
||||
|
|
|
@ -2,14 +2,12 @@ use std::{future::Future, io, pin::Pin, task::Context, task::Poll};
|
|||
|
||||
use ntex_openssl::{SslConnector as IoSslConnector, SslFilter};
|
||||
pub use open_ssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
|
||||
pub use tokio_openssl::SslStream;
|
||||
|
||||
use crate::io::{DefaultFilter, Io};
|
||||
use crate::rt::net::TcpStream;
|
||||
use crate::service::{Service, ServiceFactory};
|
||||
use crate::util::Ready;
|
||||
|
||||
use super::{Address, Connect, ConnectError, Connector, IoConnector as BaseIoConnector};
|
||||
use super::{Address, Connect, ConnectError, Connector};
|
||||
|
||||
pub struct OpensslConnector<T> {
|
||||
connector: Connector<T>,
|
||||
|
@ -27,103 +25,6 @@ impl<T> OpensslConnector<T> {
|
|||
}
|
||||
|
||||
impl<T: Address + 'static> OpensslConnector<T> {
|
||||
/// Resolve and connect to remote host
|
||||
pub fn connect<U>(
|
||||
&self,
|
||||
message: U,
|
||||
) -> impl Future<Output = Result<SslStream<TcpStream>, ConnectError>>
|
||||
where
|
||||
Connect<T>: From<U>,
|
||||
{
|
||||
let message = Connect::from(message);
|
||||
let host = message.host().to_string();
|
||||
let conn = self.connector.call(message);
|
||||
let openssl = self.openssl.clone();
|
||||
|
||||
async move {
|
||||
let io = conn.await?;
|
||||
trace!("SSL Handshake start for: {:?}", host);
|
||||
|
||||
match openssl.configure() {
|
||||
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
|
||||
Ok(config) => {
|
||||
let config = config
|
||||
.into_ssl(&host)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
let mut io = SslStream::new(config, io)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
match Pin::new(&mut io).connect().await {
|
||||
Ok(_) => {
|
||||
trace!("SSL Handshake success: {:?}", host);
|
||||
Ok(io)
|
||||
}
|
||||
Err(e) => {
|
||||
trace!("SSL Handshake error: {:?}", e);
|
||||
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))
|
||||
.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for OpensslConnector<T> {
|
||||
fn clone(&self) -> Self {
|
||||
OpensslConnector {
|
||||
connector: self.connector.clone(),
|
||||
openssl: self.openssl.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address + 'static> ServiceFactory for OpensslConnector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = SslStream<TcpStream>;
|
||||
type Error = ConnectError;
|
||||
type Config = ();
|
||||
type Service = OpensslConnector<T>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Self::Service, Self::InitError>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
Ready::Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address + 'static> Service for OpensslConnector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = SslStream<TcpStream>;
|
||||
type Error = ConnectError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&self, req: Connect<T>) -> Self::Future {
|
||||
Box::pin(self.connect(req))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IoConnector<T> {
|
||||
connector: BaseIoConnector<T>,
|
||||
openssl: SslConnector,
|
||||
}
|
||||
|
||||
impl<T> IoConnector<T> {
|
||||
/// Construct new OpensslConnectService factory
|
||||
pub fn new(connector: SslConnector) -> Self {
|
||||
IoConnector {
|
||||
connector: BaseIoConnector::default(),
|
||||
openssl: connector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address + 'static> IoConnector<T> {
|
||||
/// Resolve and connect to remote host
|
||||
pub fn connect<U>(
|
||||
&self,
|
||||
|
@ -164,21 +65,21 @@ impl<T: Address + 'static> IoConnector<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for IoConnector<T> {
|
||||
impl<T> Clone for OpensslConnector<T> {
|
||||
fn clone(&self) -> Self {
|
||||
IoConnector {
|
||||
OpensslConnector {
|
||||
connector: self.connector.clone(),
|
||||
openssl: self.openssl.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address + 'static> ServiceFactory for IoConnector<T> {
|
||||
impl<T: Address + 'static> ServiceFactory for OpensslConnector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = Io<SslFilter<DefaultFilter>>;
|
||||
type Error = ConnectError;
|
||||
type Config = ();
|
||||
type Service = IoConnector<T>;
|
||||
type Service = OpensslConnector<T>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Self::Service, Self::InitError>;
|
||||
|
||||
|
@ -187,7 +88,7 @@ impl<T: Address + 'static> ServiceFactory for IoConnector<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Address + 'static> Service for IoConnector<T> {
|
||||
impl<T: Address + 'static> Service for OpensslConnector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = Io<SslFilter<DefaultFilter>>;
|
||||
type Error = ConnectError;
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
use std::task::{Context, Poll};
|
||||
use std::{collections::VecDeque, future::Future, io, net::SocketAddr, pin::Pin};
|
||||
|
||||
use crate::io::Io;
|
||||
use crate::rt::net::TcpStream;
|
||||
use crate::service::{Service, ServiceFactory};
|
||||
use crate::util::{Either, Ready};
|
||||
use crate::util::{Either, PoolId, PoolRef, Ready};
|
||||
|
||||
use super::{Address, Connect, ConnectError, Resolver};
|
||||
|
||||
pub struct Connector<T> {
|
||||
resolver: Resolver<T>,
|
||||
pool: PoolRef,
|
||||
}
|
||||
|
||||
impl<T> Connector<T> {
|
||||
|
@ -16,8 +18,18 @@ impl<T> Connector<T> {
|
|||
pub fn new() -> Self {
|
||||
Connector {
|
||||
resolver: Resolver::new(),
|
||||
pool: PoolId::P0.pool_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set memory pool.
|
||||
///
|
||||
/// Use specified memory pool for memory allocations. By default P0
|
||||
/// memory pool is used.
|
||||
pub fn memory_pool(mut self, id: PoolId) -> Self {
|
||||
self.pool = id.pool_ref();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> Connector<T> {
|
||||
|
@ -25,11 +37,14 @@ impl<T: Address> Connector<T> {
|
|||
pub fn connect<U>(
|
||||
&self,
|
||||
message: U,
|
||||
) -> impl Future<Output = Result<TcpStream, ConnectError>>
|
||||
) -> impl Future<Output = Result<Io, ConnectError>>
|
||||
where
|
||||
Connect<T>: From<U>,
|
||||
{
|
||||
ConnectServiceResponse::new(self.resolver.call(message.into()))
|
||||
ConnectServiceResponse {
|
||||
state: ConnectState::Resolve(self.resolver.call(message.into())),
|
||||
pool: self.pool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,13 +58,14 @@ impl<T> Clone for Connector<T> {
|
|||
fn clone(&self) -> Self {
|
||||
Connector {
|
||||
resolver: self.resolver.clone(),
|
||||
pool: self.pool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> ServiceFactory for Connector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = TcpStream;
|
||||
type Response = Io;
|
||||
type Error = ConnectError;
|
||||
type Config = ();
|
||||
type Service = Connector<T>;
|
||||
|
@ -64,7 +80,7 @@ impl<T: Address> ServiceFactory for Connector<T> {
|
|||
|
||||
impl<T: Address> Service for Connector<T> {
|
||||
type Request = Connect<T>;
|
||||
type Response = TcpStream;
|
||||
type Response = Io;
|
||||
type Error = ConnectError;
|
||||
type Future = ConnectServiceResponse<T>;
|
||||
|
||||
|
@ -87,18 +103,20 @@ enum ConnectState<T: Address> {
|
|||
#[doc(hidden)]
|
||||
pub struct ConnectServiceResponse<T: Address> {
|
||||
state: ConnectState<T>,
|
||||
pool: PoolRef,
|
||||
}
|
||||
|
||||
impl<T: Address> ConnectServiceResponse<T> {
|
||||
pub(super) fn new(fut: <Resolver<T> as Service>::Future) -> Self {
|
||||
ConnectServiceResponse {
|
||||
Self {
|
||||
state: ConnectState::Resolve(fut),
|
||||
pool: PoolId::P0.pool_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> Future for ConnectServiceResponse<T> {
|
||||
type Output = Result<TcpStream, ConnectError>;
|
||||
type Output = Result<Io, ConnectError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
match self.state {
|
||||
|
@ -126,7 +144,12 @@ impl<T: Address> Future for ConnectServiceResponse<T> {
|
|||
}
|
||||
}
|
||||
},
|
||||
ConnectState::Connect(ref mut fut) => Pin::new(fut).poll(cx),
|
||||
ConnectState::Connect(ref mut fut) => match Pin::new(fut).poll(cx)? {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(stream) => {
|
||||
Poll::Ready(Ok(Io::with_memory_pool(stream, self.pool)))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,901 +0,0 @@
|
|||
//! Framed transport dispatcher
|
||||
use std::{
|
||||
cell::Cell, cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context,
|
||||
task::Poll, time, time::Instant,
|
||||
};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
|
||||
use crate::framed::{DispatchItem, Read, ReadTask, State, Timer, Write, WriteTask};
|
||||
use crate::service::{IntoService, Service};
|
||||
use crate::time::Seconds;
|
||||
use crate::util::{Either, Pool};
|
||||
|
||||
type Response<U> = <U as Encoder>::Item;
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
/// Framed dispatcher - is a future that reads frames from Framed object
|
||||
/// and pass then to the service.
|
||||
pub struct Dispatcher<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
U: Encoder,
|
||||
U: Decoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
{
|
||||
service: S,
|
||||
inner: DispatcherInner<S, U>,
|
||||
#[pin]
|
||||
fut: Option<S::Future>,
|
||||
}
|
||||
}
|
||||
|
||||
struct DispatcherInner<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||
U: Encoder + Decoder,
|
||||
{
|
||||
st: Cell<DispatcherState>,
|
||||
state: State,
|
||||
timer: Timer,
|
||||
ka_timeout: Seconds,
|
||||
ka_updated: Cell<Instant>,
|
||||
error: Cell<Option<S::Error>>,
|
||||
shared: Rc<DispatcherShared<S, U>>,
|
||||
pool: Pool,
|
||||
}
|
||||
|
||||
struct DispatcherShared<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||
U: Encoder + Decoder,
|
||||
{
|
||||
codec: U,
|
||||
error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
|
||||
inflight: Cell<usize>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
enum DispatcherState {
|
||||
Processing,
|
||||
Backpressure,
|
||||
Stop,
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
enum DispatcherError<S, U> {
|
||||
KeepAlive,
|
||||
Encoder(U),
|
||||
Service(S),
|
||||
}
|
||||
|
||||
enum PollService<U: Encoder + Decoder> {
|
||||
Item(DispatchItem<U>),
|
||||
ServiceError,
|
||||
Ready,
|
||||
}
|
||||
|
||||
impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
|
||||
fn from(err: Either<S, U>) -> Self {
|
||||
match err {
|
||||
Either::Left(err) => DispatcherError::Service(err),
|
||||
Either::Right(err) => DispatcherError::Encoder(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U> Dispatcher<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
|
||||
U: Decoder + Encoder + 'static,
|
||||
<U as Encoder>::Item: 'static,
|
||||
{
|
||||
/// Construct new `Dispatcher` instance.
|
||||
pub fn new<T, F: IntoService<S>>(
|
||||
io: T,
|
||||
codec: U,
|
||||
state: State,
|
||||
service: F,
|
||||
timer: Timer,
|
||||
) -> Self
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
let io = Rc::new(RefCell::new(io));
|
||||
|
||||
// start support tasks
|
||||
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
|
||||
crate::rt::spawn(WriteTask::new(io, state.clone()));
|
||||
|
||||
Self::from_state(codec, state, service, timer)
|
||||
}
|
||||
|
||||
/// Construct new `Dispatcher` instance.
|
||||
pub fn from_state<F: IntoService<S>>(
|
||||
codec: U,
|
||||
state: State,
|
||||
service: F,
|
||||
timer: Timer,
|
||||
) -> Self {
|
||||
let updated = timer.now();
|
||||
let ka_timeout = Seconds(30);
|
||||
|
||||
// register keepalive timer
|
||||
let expire = updated + time::Duration::from(ka_timeout);
|
||||
timer.register(expire, expire, &state);
|
||||
|
||||
Dispatcher {
|
||||
service: service.into_service(),
|
||||
fut: None,
|
||||
inner: DispatcherInner {
|
||||
pool: state.memory_pool().pool(),
|
||||
ka_updated: Cell::new(updated),
|
||||
error: Cell::new(None),
|
||||
st: Cell::new(DispatcherState::Processing),
|
||||
shared: Rc::new(DispatcherShared {
|
||||
codec,
|
||||
error: Cell::new(None),
|
||||
inflight: Cell::new(0),
|
||||
}),
|
||||
state,
|
||||
timer,
|
||||
ka_timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Set keep-alive timeout.
|
||||
///
|
||||
/// To disable timeout set value to 0.
|
||||
///
|
||||
/// By default keep-alive timeout is set to 30 seconds.
|
||||
pub fn keepalive_timeout(mut self, timeout: Seconds) -> Self {
|
||||
// register keepalive timer
|
||||
let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka());
|
||||
if timeout.is_zero() {
|
||||
self.inner.timer.unregister(prev, &self.inner.state);
|
||||
} else {
|
||||
let expire = self.inner.ka_updated.get() + time::Duration::from(timeout);
|
||||
self.inner.timer.register(expire, prev, &self.inner.state);
|
||||
}
|
||||
self.inner.ka_timeout = timeout;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Set connection disconnect timeout in seconds.
|
||||
///
|
||||
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
|
||||
/// within this time, the connection get dropped.
|
||||
///
|
||||
/// To disable timeout set value to 0.
|
||||
///
|
||||
/// By default disconnect timeout is set to 1 seconds.
|
||||
pub fn disconnect_timeout(self, val: Seconds) -> Self {
|
||||
self.inner.state.set_disconnect_timeout(val);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U> DispatcherShared<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
U: Encoder + Decoder,
|
||||
<U as Encoder>::Item: 'static,
|
||||
{
|
||||
fn handle_result(&self, item: Result<S::Response, S::Error>, write: Write<'_>) {
|
||||
self.inflight.set(self.inflight.get() - 1);
|
||||
match write.encode_result(item, &self.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(err) => self.error.set(Some(err.into())),
|
||||
}
|
||||
write.wake_dispatcher();
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U> Future for Dispatcher<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
|
||||
U: Decoder + Encoder + 'static,
|
||||
<U as Encoder>::Item: 'static,
|
||||
{
|
||||
type Output = Result<(), S::Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.as_mut().project();
|
||||
let slf = &this.inner;
|
||||
let state = &slf.state;
|
||||
let read = state.read();
|
||||
let write = state.write();
|
||||
|
||||
// handle service response future
|
||||
if let Some(fut) = this.fut.as_mut().as_pin_mut() {
|
||||
match fut.poll(cx) {
|
||||
Poll::Pending => (),
|
||||
Poll::Ready(item) => {
|
||||
this.fut.set(None);
|
||||
slf.shared.inflight.set(slf.shared.inflight.get() - 1);
|
||||
slf.handle_result(item, write);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle memory pool pressure
|
||||
if slf.pool.poll_ready(cx).is_pending() {
|
||||
read.pause(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
loop {
|
||||
match slf.st.get() {
|
||||
DispatcherState::Processing => {
|
||||
let result = match slf.poll_service(this.service, cx, read) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(result) => result,
|
||||
};
|
||||
|
||||
let item = match result {
|
||||
PollService::Ready => {
|
||||
if !write.is_ready() {
|
||||
// instruct write task to notify dispatcher when data is flushed
|
||||
write.enable_backpressure(Some(cx.waker()));
|
||||
slf.st.set(DispatcherState::Backpressure);
|
||||
DispatchItem::WBackPressureEnabled
|
||||
} else if read.is_ready() {
|
||||
// decode incoming bytes if buffer is ready
|
||||
match read.decode(&slf.shared.codec) {
|
||||
Ok(Some(el)) => {
|
||||
slf.update_keepalive();
|
||||
DispatchItem::Item(el)
|
||||
}
|
||||
Ok(None) => {
|
||||
log::trace!("not enough data to decode next frame, register dispatch task");
|
||||
read.wake(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
Err(err) => {
|
||||
slf.st.set(DispatcherState::Stop);
|
||||
slf.unregister_keepalive();
|
||||
DispatchItem::DecoderError(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// no new events
|
||||
state.register_dispatcher(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
PollService::Item(item) => item,
|
||||
PollService::ServiceError => continue,
|
||||
};
|
||||
|
||||
// call service
|
||||
if this.fut.is_none() {
|
||||
// optimize first service call
|
||||
this.fut.set(Some(this.service.call(item)));
|
||||
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
||||
Poll::Ready(res) => {
|
||||
this.fut.set(None);
|
||||
slf.handle_result(res, write);
|
||||
}
|
||||
Poll::Pending => {
|
||||
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slf.spawn_service_call(this.service.call(item));
|
||||
}
|
||||
}
|
||||
// handle write back-pressure
|
||||
DispatcherState::Backpressure => {
|
||||
let result = match slf.poll_service(this.service, cx, read) {
|
||||
Poll::Ready(result) => result,
|
||||
Poll::Pending => return Poll::Pending,
|
||||
};
|
||||
let item = match result {
|
||||
PollService::Ready => {
|
||||
if write.is_ready() {
|
||||
slf.st.set(DispatcherState::Processing);
|
||||
DispatchItem::WBackPressureDisabled
|
||||
} else {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
PollService::Item(item) => item,
|
||||
PollService::ServiceError => continue,
|
||||
};
|
||||
|
||||
// call service
|
||||
if this.fut.is_none() {
|
||||
// optimize first service call
|
||||
this.fut.set(Some(this.service.call(item)));
|
||||
match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
|
||||
Poll::Ready(res) => {
|
||||
this.fut.set(None);
|
||||
slf.handle_result(res, write);
|
||||
}
|
||||
Poll::Pending => {
|
||||
slf.shared.inflight.set(slf.shared.inflight.get() + 1)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slf.spawn_service_call(this.service.call(item));
|
||||
}
|
||||
}
|
||||
// drain service responses
|
||||
DispatcherState::Stop => {
|
||||
// service may relay on poll_ready for response results
|
||||
if !this.inner.state.is_dispatcher_ready_err() {
|
||||
let _ = this.service.poll_ready(cx);
|
||||
}
|
||||
|
||||
if slf.shared.inflight.get() == 0 {
|
||||
slf.st.set(DispatcherState::Shutdown);
|
||||
state.shutdown_io();
|
||||
} else {
|
||||
state.register_dispatcher(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
// shutdown service
|
||||
DispatcherState::Shutdown => {
|
||||
let err = slf.error.take();
|
||||
|
||||
return if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
|
||||
log::trace!("service shutdown is completed, stop");
|
||||
|
||||
Poll::Ready(if let Some(err) = err {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(())
|
||||
})
|
||||
} else {
|
||||
slf.error.set(err);
|
||||
Poll::Pending
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, U> DispatcherInner<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
|
||||
U: Decoder + Encoder + 'static,
|
||||
{
|
||||
/// spawn service call
|
||||
fn spawn_service_call(&self, fut: S::Future) {
|
||||
self.shared.inflight.set(self.shared.inflight.get() + 1);
|
||||
|
||||
let st = self.state.clone();
|
||||
let shared = self.shared.clone();
|
||||
crate::rt::spawn(async move {
|
||||
let item = fut.await;
|
||||
shared.handle_result(item, st.write());
|
||||
});
|
||||
}
|
||||
|
||||
fn handle_result(
|
||||
&self,
|
||||
item: Result<Option<<U as Encoder>::Item>, S::Error>,
|
||||
write: Write<'_>,
|
||||
) {
|
||||
match write.encode_result(item, &self.shared.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(Either::Left(err)) => {
|
||||
self.error.set(Some(err));
|
||||
}
|
||||
Err(Either::Right(err)) => {
|
||||
self.shared.error.set(Some(DispatcherError::Encoder(err)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_service(
|
||||
&self,
|
||||
srv: &S,
|
||||
cx: &mut Context<'_>,
|
||||
read: Read<'_>,
|
||||
) -> Poll<PollService<U>> {
|
||||
match srv.poll_ready(cx) {
|
||||
Poll::Ready(Ok(_)) => {
|
||||
// service is ready, wake io read task
|
||||
read.resume();
|
||||
|
||||
// check keepalive timeout
|
||||
self.check_keepalive();
|
||||
|
||||
// check for errors
|
||||
Poll::Ready(if let Some(err) = self.shared.error.take() {
|
||||
log::trace!("error occured, stopping dispatcher");
|
||||
self.unregister_keepalive();
|
||||
self.st.set(DispatcherState::Stop);
|
||||
|
||||
match err {
|
||||
DispatcherError::KeepAlive => {
|
||||
PollService::Item(DispatchItem::KeepAliveTimeout)
|
||||
}
|
||||
DispatcherError::Encoder(err) => {
|
||||
PollService::Item(DispatchItem::EncoderError(err))
|
||||
}
|
||||
DispatcherError::Service(err) => {
|
||||
self.error.set(Some(err));
|
||||
PollService::ServiceError
|
||||
}
|
||||
}
|
||||
} else if self.state.is_dispatcher_stopped() {
|
||||
log::trace!("dispatcher is instructed to stop");
|
||||
|
||||
self.unregister_keepalive();
|
||||
|
||||
// process unhandled data
|
||||
if let Ok(Some(el)) = read.decode(&self.shared.codec) {
|
||||
PollService::Item(DispatchItem::Item(el))
|
||||
} else {
|
||||
self.st.set(DispatcherState::Stop);
|
||||
|
||||
// get io error
|
||||
if let Some(err) = self.state.take_io_error() {
|
||||
PollService::Item(DispatchItem::IoError(err))
|
||||
} else {
|
||||
PollService::ServiceError
|
||||
}
|
||||
}
|
||||
} else {
|
||||
PollService::Ready
|
||||
})
|
||||
}
|
||||
// pause io read task
|
||||
Poll::Pending => {
|
||||
log::trace!("service is not ready, register dispatch task");
|
||||
read.pause(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
// handle service readiness error
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("service readiness check failed, stopping");
|
||||
self.st.set(DispatcherState::Stop);
|
||||
self.error.set(Some(err));
|
||||
self.unregister_keepalive();
|
||||
self.state.dispatcher_ready_err();
|
||||
Poll::Ready(PollService::ServiceError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ka(&self) -> Seconds {
|
||||
self.ka_timeout
|
||||
}
|
||||
|
||||
fn ka_enabled(&self) -> bool {
|
||||
self.ka_timeout.non_zero()
|
||||
}
|
||||
|
||||
/// check keepalive timeout
|
||||
fn check_keepalive(&self) {
|
||||
if self.state.is_keepalive() {
|
||||
log::trace!("keepalive timeout");
|
||||
if let Some(err) = self.shared.error.take() {
|
||||
self.shared.error.set(Some(err));
|
||||
} else {
|
||||
self.shared.error.set(Some(DispatcherError::KeepAlive));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// update keep-alive timer
|
||||
fn update_keepalive(&self) {
|
||||
if self.ka_enabled() {
|
||||
let updated = self.timer.now();
|
||||
if updated != self.ka_updated.get() {
|
||||
let ka = time::Duration::from(self.ka());
|
||||
self.timer.register(
|
||||
updated + ka,
|
||||
self.ka_updated.get() + ka,
|
||||
&self.state,
|
||||
);
|
||||
self.ka_updated.set(updated);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// unregister keep-alive timer
|
||||
fn unregister_keepalive(&self) {
|
||||
if self.ka_enabled() {
|
||||
self.timer.unregister(
|
||||
self.ka_updated.get() + time::Duration::from(self.ka()),
|
||||
&self.state,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::Rng;
|
||||
use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::codec::BytesCodec;
|
||||
use crate::testing::Io;
|
||||
use crate::time::{sleep, Millis};
|
||||
use crate::util::{Bytes, PoolRef, Ready};
|
||||
|
||||
use super::*;
|
||||
|
||||
impl<S, U> Dispatcher<S, U>
|
||||
where
|
||||
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
|
||||
S::Error: 'static,
|
||||
S::Future: 'static,
|
||||
U: Decoder + Encoder + 'static,
|
||||
<U as Encoder>::Item: 'static,
|
||||
{
|
||||
/// Construct new `Dispatcher` instance
|
||||
pub(crate) fn debug<T, F: IntoService<S>>(
|
||||
io: T,
|
||||
codec: U,
|
||||
service: F,
|
||||
) -> (Self, State)
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
let timer = Timer::default();
|
||||
let ka_timeout = Seconds(1);
|
||||
let ka_updated = timer.now();
|
||||
let state = State::new();
|
||||
let io = Rc::new(RefCell::new(io));
|
||||
let shared = Rc::new(DispatcherShared {
|
||||
codec: codec,
|
||||
error: Cell::new(None),
|
||||
inflight: Cell::new(0),
|
||||
});
|
||||
|
||||
let expire = ka_updated + Duration::from_millis(500);
|
||||
timer.register(expire, expire, &state);
|
||||
|
||||
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
|
||||
crate::rt::spawn(WriteTask::new(io.clone(), state.clone()));
|
||||
|
||||
(
|
||||
Dispatcher {
|
||||
service: service.into_service(),
|
||||
fut: None,
|
||||
inner: DispatcherInner {
|
||||
shared,
|
||||
timer,
|
||||
ka_timeout,
|
||||
ka_updated: Cell::new(ka_updated),
|
||||
state: state.clone(),
|
||||
error: Cell::new(None),
|
||||
st: Cell::new(DispatcherState::Processing),
|
||||
pool: state.memory_pool().pool(),
|
||||
},
|
||||
},
|
||||
state,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_basic() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let (disp, _) = Dispatcher::debug(
|
||||
server,
|
||||
BytesCodec,
|
||||
crate::service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
|
||||
sleep(Millis(50)).await;
|
||||
if let DispatchItem::Item(msg) = msg {
|
||||
Ok::<_, ()>(Some(msg.freeze()))
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}),
|
||||
);
|
||||
crate::rt::spawn(async move {
|
||||
let _ = disp.await;
|
||||
});
|
||||
|
||||
sleep(Millis(25)).await;
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_sink() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let (disp, st) = Dispatcher::debug(
|
||||
server,
|
||||
BytesCodec,
|
||||
crate::service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
|
||||
if let DispatchItem::Item(msg) = msg {
|
||||
Ok::<_, ()>(Some(msg.freeze()))
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}),
|
||||
);
|
||||
crate::rt::spawn(async move {
|
||||
let _ = disp.disconnect_timeout(Seconds(1)).await;
|
||||
});
|
||||
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
assert!(st
|
||||
.write()
|
||||
.encode(Bytes::from_static(b"test"), &mut BytesCodec)
|
||||
.is_ok());
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"test"));
|
||||
|
||||
st.close();
|
||||
sleep(Millis(1100)).await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_err_in_service() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(0);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let (disp, state) = Dispatcher::debug(
|
||||
server,
|
||||
BytesCodec,
|
||||
crate::service::fn_service(|_: DispatchItem<BytesCodec>| async move {
|
||||
Err::<Option<Bytes>, _>(())
|
||||
}),
|
||||
);
|
||||
state
|
||||
.write()
|
||||
.encode(
|
||||
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
|
||||
&mut BytesCodec,
|
||||
)
|
||||
.unwrap();
|
||||
crate::rt::spawn(async move {
|
||||
let _ = disp.await;
|
||||
});
|
||||
|
||||
// buffer should be flushed
|
||||
client.remote_buffer_cap(1024);
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
// write side must be closed, dispatcher waiting for read side to close
|
||||
assert!(client.is_closed());
|
||||
|
||||
// close read side
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
}
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_err_in_service_ready() {
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(0);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let counter = Rc::new(Cell::new(0));
|
||||
|
||||
struct Srv(Rc<Cell<usize>>);
|
||||
|
||||
impl Service for Srv {
|
||||
type Request = DispatchItem<BytesCodec>;
|
||||
type Response = Option<Response<BytesCodec>>;
|
||||
type Error = ();
|
||||
type Future = Ready<Option<Response<BytesCodec>>, ()>;
|
||||
|
||||
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
|
||||
self.0.set(self.0.get() + 1);
|
||||
Poll::Ready(Err(()))
|
||||
}
|
||||
|
||||
fn call(&self, _: DispatchItem<BytesCodec>) -> Self::Future {
|
||||
Ready::Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
|
||||
state
|
||||
.write()
|
||||
.encode(
|
||||
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
|
||||
&mut BytesCodec,
|
||||
)
|
||||
.unwrap();
|
||||
crate::rt::spawn(async move {
|
||||
let _ = disp.await;
|
||||
});
|
||||
|
||||
// buffer should be flushed
|
||||
client.remote_buffer_cap(1024);
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
||||
// write side must be closed, dispatcher waiting for read side to close
|
||||
assert!(client.is_closed());
|
||||
|
||||
// close read side
|
||||
client.close().await;
|
||||
assert!(client.is_server_dropped());
|
||||
|
||||
// service must be checked for readiness only once
|
||||
assert_eq!(counter.get(), 1);
|
||||
}
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_write_backpressure() {
|
||||
let (client, server) = Io::create();
|
||||
// do not allow to write to socket
|
||||
client.remote_buffer_cap(0);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
|
||||
let data2 = data.clone();
|
||||
|
||||
let (disp, state) = Dispatcher::debug(
|
||||
server,
|
||||
BytesCodec,
|
||||
crate::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
|
||||
let data = data2.clone();
|
||||
async move {
|
||||
match msg {
|
||||
DispatchItem::Item(_) => {
|
||||
data.lock().unwrap().borrow_mut().push(0);
|
||||
let bytes = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(65_536)
|
||||
.map(char::from)
|
||||
.collect::<String>();
|
||||
return Ok::<_, ()>(Some(Bytes::from(bytes)));
|
||||
}
|
||||
DispatchItem::WBackPressureEnabled => {
|
||||
data.lock().unwrap().borrow_mut().push(1);
|
||||
}
|
||||
DispatchItem::WBackPressureDisabled => {
|
||||
data.lock().unwrap().borrow_mut().push(2);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}),
|
||||
);
|
||||
let pool = PoolRef::default();
|
||||
pool.set_read_params(8 * 1024, 1024);
|
||||
pool.set_write_params(16 * 1024, 1024);
|
||||
crate::rt::spawn(async move {
|
||||
let _ = disp.await;
|
||||
});
|
||||
|
||||
let buf = client.read_any();
|
||||
assert_eq!(buf, Bytes::from_static(b""));
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
sleep(Millis(25)).await;
|
||||
|
||||
// buf must be consumed
|
||||
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
|
||||
|
||||
// response message
|
||||
assert!(!state.write().is_ready());
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()), 65536);
|
||||
|
||||
client.remote_buffer_cap(10240);
|
||||
sleep(Millis(50)).await;
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()), 55296);
|
||||
|
||||
client.remote_buffer_cap(45056);
|
||||
sleep(Millis(50)).await;
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()), 10240);
|
||||
|
||||
// backpressure disabled
|
||||
assert!(state.write().is_ready());
|
||||
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
|
||||
}
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_keepalive() {
|
||||
let (client, server) = Io::create();
|
||||
// do not allow to write to socket
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
|
||||
let data2 = data.clone();
|
||||
|
||||
let (disp, state) = Dispatcher::debug(
|
||||
server,
|
||||
BytesCodec,
|
||||
crate::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
|
||||
let data = data2.clone();
|
||||
async move {
|
||||
match msg {
|
||||
DispatchItem::Item(bytes) => {
|
||||
data.lock().unwrap().borrow_mut().push(0);
|
||||
return Ok::<_, ()>(Some(bytes.freeze()));
|
||||
}
|
||||
DispatchItem::KeepAliveTimeout => {
|
||||
data.lock().unwrap().borrow_mut().push(1);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}),
|
||||
);
|
||||
crate::rt::spawn(async move {
|
||||
let _ = disp
|
||||
.keepalive_timeout(Seconds::ZERO)
|
||||
.keepalive_timeout(Seconds(1))
|
||||
.await;
|
||||
});
|
||||
|
||||
state.set_disconnect_timeout(Seconds(1));
|
||||
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
sleep(Millis(3500)).await;
|
||||
|
||||
// write side must be closed, dispatcher should fail with keep-alive
|
||||
let flags = state.flags();
|
||||
assert!(state.is_io_err());
|
||||
assert!(state.is_io_shutdown());
|
||||
assert!(flags.contains(crate::framed::state::Flags::IO_SHUTDOWN));
|
||||
assert!(client.is_closed());
|
||||
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
|
||||
}
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn test_unhandled_data() {
|
||||
let handled = Arc::new(AtomicBool::new(false));
|
||||
let handled2 = handled.clone();
|
||||
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
let (disp, _) = Dispatcher::debug(
|
||||
server,
|
||||
BytesCodec,
|
||||
crate::service::fn_service(move |msg: DispatchItem<BytesCodec>| {
|
||||
handled2.store(true, Relaxed);
|
||||
async move {
|
||||
sleep(Millis(50)).await;
|
||||
if let DispatchItem::Item(msg) = msg {
|
||||
Ok::<_, ()>(Some(msg.freeze()))
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
client.close().await;
|
||||
crate::rt::spawn(async move {
|
||||
let _ = disp.await;
|
||||
});
|
||||
sleep(Millis(50)).await;
|
||||
|
||||
assert!(handled.load(Relaxed));
|
||||
}
|
||||
}
|
|
@ -1,89 +0,0 @@
|
|||
use std::{fmt, io};
|
||||
|
||||
mod dispatcher;
|
||||
mod read;
|
||||
mod state;
|
||||
mod time;
|
||||
mod write;
|
||||
|
||||
pub use self::dispatcher::Dispatcher;
|
||||
pub use self::read::ReadTask;
|
||||
pub use self::state::{OnDisconnect, Read, State, Write};
|
||||
pub use self::time::Timer;
|
||||
pub use self::write::WriteTask;
|
||||
|
||||
use crate::codec::{Decoder, Encoder};
|
||||
|
||||
/// Framed transport item
|
||||
pub enum DispatchItem<U: Encoder + Decoder> {
|
||||
Item(<U as Decoder>::Item),
|
||||
/// Write back-pressure enabled
|
||||
WBackPressureEnabled,
|
||||
/// Write back-pressure disabled
|
||||
WBackPressureDisabled,
|
||||
/// Keep alive timeout
|
||||
KeepAliveTimeout,
|
||||
/// Decoder parse error
|
||||
DecoderError(<U as Decoder>::Error),
|
||||
/// Encoder parse error
|
||||
EncoderError(<U as Encoder>::Error),
|
||||
/// Unexpected io error
|
||||
IoError(io::Error),
|
||||
}
|
||||
|
||||
impl<U> fmt::Debug for DispatchItem<U>
|
||||
where
|
||||
U: Encoder + Decoder,
|
||||
<U as Decoder>::Item: fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match *self {
|
||||
DispatchItem::Item(ref item) => {
|
||||
write!(fmt, "DispatchItem::Item({:?})", item)
|
||||
}
|
||||
DispatchItem::WBackPressureEnabled => {
|
||||
write!(fmt, "DispatchItem::WBackPressureEnabled")
|
||||
}
|
||||
DispatchItem::WBackPressureDisabled => {
|
||||
write!(fmt, "DispatchItem::WBackPressureDisabled")
|
||||
}
|
||||
DispatchItem::KeepAliveTimeout => {
|
||||
write!(fmt, "DispatchItem::KeepAliveTimeout")
|
||||
}
|
||||
DispatchItem::EncoderError(ref e) => {
|
||||
write!(fmt, "DispatchItem::EncoderError({:?})", e)
|
||||
}
|
||||
DispatchItem::DecoderError(ref e) => {
|
||||
write!(fmt, "DispatchItem::DecoderError({:?})", e)
|
||||
}
|
||||
DispatchItem::IoError(ref e) => {
|
||||
write!(fmt, "DispatchItem::IoError({:?})", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codec::BytesCodec;
|
||||
|
||||
#[test]
|
||||
fn test_fmt() {
|
||||
type T = DispatchItem<BytesCodec>;
|
||||
|
||||
let err = T::EncoderError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::Encoder"));
|
||||
let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::Decoder"));
|
||||
let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::IoError"));
|
||||
|
||||
assert!(format!("{:?}", T::WBackPressureEnabled)
|
||||
.contains("DispatchItem::WBackPressureEnabled"));
|
||||
assert!(format!("{:?}", T::WBackPressureDisabled)
|
||||
.contains("DispatchItem::WBackPressureDisabled"));
|
||||
assert!(format!("{:?}", T::KeepAliveTimeout)
|
||||
.contains("DispatchItem::KeepAliveTimeout"));
|
||||
}
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::framed::State;
|
||||
|
||||
/// Read io task
|
||||
pub struct ReadTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
io: Rc<RefCell<T>>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl<T> ReadTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
/// Create new read io task
|
||||
pub fn new(io: Rc<RefCell<T>>, state: State) -> Self {
|
||||
Self { io, state }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for ReadTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if self.state.is_io_shutdown() {
|
||||
log::trace!("read task is instructed to shutdown");
|
||||
Poll::Ready(())
|
||||
} else if self.state.is_io_stop() {
|
||||
self.state.wake_dispatcher();
|
||||
Poll::Ready(())
|
||||
} else if self.state.is_read_paused() {
|
||||
self.state.register_read_task(cx.waker());
|
||||
Poll::Pending
|
||||
} else {
|
||||
let mut io = self.io.borrow_mut();
|
||||
if self.state.read_io(&mut *io, cx) {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,115 +0,0 @@
|
|||
use std::{cell::RefCell, collections::BTreeMap, rc::Rc, time::Instant};
|
||||
|
||||
use crate::framed::State;
|
||||
use crate::time::{now, sleep, Millis};
|
||||
use crate::util::HashSet;
|
||||
|
||||
pub struct Timer(Rc<RefCell<Inner>>);
|
||||
|
||||
struct Inner {
|
||||
resolution: Millis,
|
||||
current: Option<Instant>,
|
||||
notifications: BTreeMap<Instant, HashSet<State>>,
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
fn new(resolution: Millis) -> Self {
|
||||
Inner {
|
||||
resolution,
|
||||
current: None,
|
||||
notifications: BTreeMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn unregister(&mut self, expire: Instant, state: &State) {
|
||||
if let Some(ref mut states) = self.notifications.get_mut(&expire) {
|
||||
states.remove(state);
|
||||
if states.is_empty() {
|
||||
self.notifications.remove(&expire);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Timer {
|
||||
fn clone(&self) -> Self {
|
||||
Timer(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Timer {
|
||||
fn default() -> Self {
|
||||
Timer::new(Millis::ONE_SEC)
|
||||
}
|
||||
}
|
||||
|
||||
impl Timer {
|
||||
/// Create new timer with resolution in milliseconds
|
||||
pub fn new(resolution: Millis) -> Timer {
|
||||
Timer(Rc::new(RefCell::new(Inner::new(resolution))))
|
||||
}
|
||||
|
||||
pub fn register(&self, expire: Instant, previous: Instant, state: &State) {
|
||||
{
|
||||
let mut inner = self.0.borrow_mut();
|
||||
|
||||
inner.unregister(previous, state);
|
||||
inner
|
||||
.notifications
|
||||
.entry(expire)
|
||||
.or_insert_with(HashSet::default)
|
||||
.insert(state.clone());
|
||||
}
|
||||
|
||||
let _ = self.now();
|
||||
}
|
||||
|
||||
pub fn unregister(&self, expire: Instant, state: &State) {
|
||||
self.0.borrow_mut().unregister(expire, state);
|
||||
}
|
||||
|
||||
/// Get current time. This function has to be called from
|
||||
/// future's poll method, otherwise it panics.
|
||||
pub fn now(&self) -> Instant {
|
||||
let cur = self.0.borrow().current;
|
||||
if let Some(cur) = cur {
|
||||
cur
|
||||
} else {
|
||||
let now_val = now();
|
||||
let inner = self.0.clone();
|
||||
let interval = {
|
||||
let mut b = inner.borrow_mut();
|
||||
b.current = Some(now_val);
|
||||
b.resolution
|
||||
};
|
||||
|
||||
crate::rt::spawn(async move {
|
||||
sleep(interval).await;
|
||||
let empty = {
|
||||
let mut i = inner.borrow_mut();
|
||||
let now = i.current.take().unwrap_or_else(now);
|
||||
|
||||
// notify io dispatcher
|
||||
while let Some(key) = i.notifications.keys().next() {
|
||||
let key = *key;
|
||||
if key <= now {
|
||||
for st in i.notifications.remove(&key).unwrap() {
|
||||
st.keepalive_timeout();
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
i.notifications.is_empty()
|
||||
};
|
||||
|
||||
// extra tick
|
||||
if !empty {
|
||||
let _ = Timer(inner).now();
|
||||
}
|
||||
});
|
||||
|
||||
now_val
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,166 +0,0 @@
|
|||
use std::task::{Context, Poll};
|
||||
use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use crate::framed::State;
|
||||
use crate::time::{sleep, Sleep};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum IoWriteState {
|
||||
Processing,
|
||||
Shutdown(Option<Sleep>, Shutdown),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Shutdown {
|
||||
None,
|
||||
Flushed,
|
||||
Stopping,
|
||||
}
|
||||
|
||||
/// Write io task
|
||||
pub struct WriteTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
st: IoWriteState,
|
||||
io: Rc<RefCell<T>>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl<T> WriteTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
/// Create new write io task
|
||||
pub fn new(io: Rc<RefCell<T>>, state: State) -> Self {
|
||||
Self {
|
||||
io,
|
||||
state,
|
||||
st: IoWriteState::Processing,
|
||||
}
|
||||
}
|
||||
|
||||
/// Shutdown io stream
|
||||
pub fn shutdown(io: Rc<RefCell<T>>, state: State) -> Self {
|
||||
let disconnect_timeout = state.get_disconnect_timeout();
|
||||
let st = IoWriteState::Shutdown(disconnect_timeout.map(sleep), Shutdown::None);
|
||||
|
||||
Self { st, io, state }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for WriteTask<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.as_mut().get_mut();
|
||||
|
||||
// IO error occured
|
||||
if this.state.is_io_err() {
|
||||
log::trace!("write io is closed");
|
||||
return Poll::Ready(());
|
||||
} else if this.state.is_io_stop() {
|
||||
self.state.wake_dispatcher();
|
||||
return Poll::Ready(());
|
||||
}
|
||||
|
||||
match this.st {
|
||||
IoWriteState::Processing => {
|
||||
if this.state.is_io_shutdown() {
|
||||
log::trace!("write task is instructed to shutdown");
|
||||
|
||||
let disconnect_timeout = this.state.get_disconnect_timeout();
|
||||
this.st = IoWriteState::Shutdown(
|
||||
disconnect_timeout.map(sleep),
|
||||
Shutdown::None,
|
||||
);
|
||||
return self.poll(cx);
|
||||
}
|
||||
|
||||
// flush framed instance
|
||||
match this.state.flush_io(&mut *this.io.borrow_mut(), cx) {
|
||||
Poll::Pending | Poll::Ready(true) => Poll::Pending,
|
||||
Poll::Ready(false) => Poll::Ready(()),
|
||||
}
|
||||
}
|
||||
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
|
||||
// close WRITE side and wait for disconnect on read side.
|
||||
// use disconnect timeout, otherwise it could hang forever.
|
||||
loop {
|
||||
match st {
|
||||
Shutdown::None => {
|
||||
// flush write buffer
|
||||
let result =
|
||||
this.state.flush_io(&mut *this.io.borrow_mut(), cx);
|
||||
match result {
|
||||
Poll::Ready(true) => {
|
||||
*st = Shutdown::Flushed;
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(false) => {
|
||||
this.state.set_wr_shutdown_complete();
|
||||
log::trace!(
|
||||
"write task is closed with err during flush"
|
||||
);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
Shutdown::Flushed => {
|
||||
// shutdown WRITE side
|
||||
match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx)
|
||||
{
|
||||
Poll::Ready(Ok(_)) => {
|
||||
*st = Shutdown::Stopping;
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
this.state.set_wr_shutdown_complete();
|
||||
log::trace!(
|
||||
"write task is closed with err during shutdown"
|
||||
);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
Shutdown::Stopping => {
|
||||
// read until 0 or err
|
||||
let mut buf = [0u8; 512];
|
||||
let mut io = this.io.borrow_mut();
|
||||
loop {
|
||||
let mut read_buf = ReadBuf::new(&mut buf);
|
||||
match Pin::new(&mut *io).poll_read(cx, &mut read_buf) {
|
||||
Poll::Ready(Err(_)) | Poll::Ready(Ok(_))
|
||||
if read_buf.filled().is_empty() =>
|
||||
{
|
||||
this.state.set_wr_shutdown_complete();
|
||||
log::trace!("write task is stopped");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
Poll::Pending => break,
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// disconnect timeout
|
||||
if let Some(ref delay) = delay {
|
||||
if delay.poll_elapsed(cx).is_pending() {
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
this.state.set_wr_shutdown_complete();
|
||||
log::trace!("write task is stopped after delay");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,11 +1,11 @@
|
|||
use std::{cell::RefCell, error::Error, fmt, marker::PhantomData, rc::Rc};
|
||||
|
||||
use crate::framed::State;
|
||||
use crate::http::body::MessageBody;
|
||||
use crate::http::config::{KeepAlive, OnRequest, ServiceConfig};
|
||||
use crate::http::error::ResponseError;
|
||||
use crate::http::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler};
|
||||
use crate::http::h2::H2Service;
|
||||
use crate::io::{Filter, Io, IoRef};
|
||||
// use crate::http::h2::H2Service;
|
||||
use crate::http::helpers::{Data, DataFactory};
|
||||
use crate::http::request::Request;
|
||||
use crate::http::response::Response;
|
||||
|
@ -18,7 +18,7 @@ use crate::util::PoolId;
|
|||
///
|
||||
/// This type can be used to construct an instance of `http service` through a
|
||||
/// builder-like pattern.
|
||||
pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
|
||||
pub struct HttpServiceBuilder<F, S, X = ExpectHandler, U = UpgradeHandler<F>> {
|
||||
keep_alive: KeepAlive,
|
||||
client_timeout: Millis,
|
||||
client_disconnect: Seconds,
|
||||
|
@ -26,12 +26,11 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
|
|||
pool: PoolId,
|
||||
expect: X,
|
||||
upgrade: Option<U>,
|
||||
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
|
||||
on_request: Option<OnRequest<T>>,
|
||||
_t: PhantomData<(T, S)>,
|
||||
on_request: Option<OnRequest>,
|
||||
_t: PhantomData<(F, S)>,
|
||||
}
|
||||
|
||||
impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> {
|
||||
impl<F, S> HttpServiceBuilder<F, S, ExpectHandler, UpgradeHandler<F>> {
|
||||
/// Create instance of `ServiceConfigBuilder`
|
||||
pub fn new() -> Self {
|
||||
HttpServiceBuilder {
|
||||
|
@ -42,15 +41,15 @@ impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> {
|
|||
pool: PoolId::P1,
|
||||
expect: ExpectHandler,
|
||||
upgrade: None,
|
||||
on_connect: None,
|
||||
on_request: None,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, X, U> HttpServiceBuilder<T, S, X, U>
|
||||
impl<F, S, X, U> HttpServiceBuilder<F, S, X, U>
|
||||
where
|
||||
F: Filter + 'static,
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
|
@ -61,7 +60,7 @@ where
|
|||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<Config = (), Request = (Request, T, State, Codec), Response = ()>,
|
||||
U: ServiceFactory<Config = (), Request = (Request, Io<F>, Codec), Response = ()>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
|
@ -122,29 +121,14 @@ where
|
|||
self
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[deprecated(since = "0.4.12", note = "Use memory pool config")]
|
||||
#[inline]
|
||||
/// Set read/write buffer params
|
||||
///
|
||||
/// By default read buffer is 8kb, write buffer is 8kb
|
||||
pub fn buffer_params(
|
||||
self,
|
||||
_max_read_buf_size: u16,
|
||||
_max_write_buf_size: u16,
|
||||
_min_buf_size: u16,
|
||||
) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
/// Provide service for `EXPECT: 100-Continue` support.
|
||||
///
|
||||
/// Service get called with request that contains `EXPECT` header.
|
||||
/// Service must return request in case of success, in that case
|
||||
/// request will be forwarded to main service.
|
||||
pub fn expect<F, X1>(self, expect: F) -> HttpServiceBuilder<T, S, X1, U>
|
||||
pub fn expect<XF, X1>(self, expect: XF) -> HttpServiceBuilder<F, S, X1, U>
|
||||
where
|
||||
F: IntoServiceFactory<X1>,
|
||||
XF: IntoServiceFactory<X1>,
|
||||
X1: ServiceFactory<Config = (), Request = Request, Response = Request>,
|
||||
X1::Error: ResponseError + 'static,
|
||||
X1::InitError: fmt::Debug,
|
||||
|
@ -159,7 +143,6 @@ where
|
|||
pool: self.pool,
|
||||
expect: expect.into_factory(),
|
||||
upgrade: self.upgrade,
|
||||
on_connect: self.on_connect,
|
||||
on_request: self.on_request,
|
||||
_t: PhantomData,
|
||||
}
|
||||
|
@ -169,12 +152,12 @@ where
|
|||
///
|
||||
/// If service is provided then normal requests handling get halted
|
||||
/// and this service get called with original request and framed object.
|
||||
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
|
||||
pub fn upgrade<UF, U1>(self, upgrade: UF) -> HttpServiceBuilder<F, S, X, U1>
|
||||
where
|
||||
F: IntoServiceFactory<U1>,
|
||||
UF: IntoServiceFactory<U1>,
|
||||
U1: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, T, State, Codec),
|
||||
Request = (Request, Io<F>, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U1::Error: fmt::Display + Error + 'static,
|
||||
|
@ -190,46 +173,29 @@ where
|
|||
pool: self.pool,
|
||||
expect: self.expect,
|
||||
upgrade: Some(upgrade.into_factory()),
|
||||
on_connect: self.on_connect,
|
||||
on_request: self.on_request,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set on-connect callback.
|
||||
///
|
||||
/// It get called once per connection and result of the call
|
||||
/// get stored to the request's extensions.
|
||||
pub fn on_connect<F, I>(mut self, f: F) -> Self
|
||||
where
|
||||
F: Fn(&T) -> I + 'static,
|
||||
I: Clone + 'static,
|
||||
{
|
||||
self.on_connect = Some(Rc::new(move |io| Box::new(Data(f(io)))));
|
||||
self
|
||||
}
|
||||
|
||||
/// Set req request callback.
|
||||
///
|
||||
/// It get called once per request.
|
||||
pub fn on_request<Filter, F>(mut self, f: F) -> Self
|
||||
pub fn on_request<R, FR>(mut self, f: FR) -> Self
|
||||
where
|
||||
F: IntoService<Filter>,
|
||||
Filter: Service<
|
||||
Request = (Request, Rc<RefCell<T>>),
|
||||
Response = Request,
|
||||
Error = Response,
|
||||
> + 'static,
|
||||
FR: IntoService<R>,
|
||||
R: Service<Request = (Request, IoRef), Response = Request, Error = Response>
|
||||
+ 'static,
|
||||
{
|
||||
self.on_request = Some(boxed::service(f.into_service()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Finish service configuration and create *http service* for HTTP/1 protocol.
|
||||
pub fn h1<F, B>(self, service: F) -> H1Service<T, S, B, X, U>
|
||||
pub fn h1<B, SF>(self, service: SF) -> H1Service<F, S, B, X, U>
|
||||
where
|
||||
B: MessageBody,
|
||||
F: IntoServiceFactory<S>,
|
||||
SF: IntoServiceFactory<S>,
|
||||
S::Error: ResponseError,
|
||||
S::InitError: fmt::Debug,
|
||||
S::Response: Into<Response<B>>,
|
||||
|
@ -245,15 +211,15 @@ where
|
|||
H1Service::with_config(cfg, service.into_factory())
|
||||
.expect(self.expect)
|
||||
.upgrade(self.upgrade)
|
||||
.on_connect(self.on_connect)
|
||||
.on_request(self.on_request)
|
||||
}
|
||||
|
||||
// pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
|
||||
/// Finish service configuration and create *http service* for HTTP/2 protocol.
|
||||
pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
|
||||
pub fn h2<B, SF>(self, service: SF) -> H1Service<F, S, B, X, U>
|
||||
where
|
||||
B: MessageBody + 'static,
|
||||
F: IntoServiceFactory<S>,
|
||||
SF: IntoServiceFactory<S>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
S::Response: Into<Response<B>> + 'static,
|
||||
|
@ -266,14 +232,19 @@ where
|
|||
self.handshake_timeout,
|
||||
self.pool,
|
||||
);
|
||||
H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
|
||||
|
||||
// H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
|
||||
H1Service::with_config(cfg, service.into_factory())
|
||||
.expect(self.expect)
|
||||
.upgrade(self.upgrade)
|
||||
.on_request(self.on_request)
|
||||
}
|
||||
|
||||
/// Finish service configuration and create `HttpService` instance.
|
||||
pub fn finish<F, B>(self, service: F) -> HttpService<T, S, B, X, U>
|
||||
pub fn finish<B, SF>(self, service: SF) -> HttpService<F, S, B, X, U>
|
||||
where
|
||||
B: MessageBody + 'static,
|
||||
F: IntoServiceFactory<S>,
|
||||
SF: IntoServiceFactory<S>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
S::Response: Into<Response<B>> + 'static,
|
||||
|
@ -290,7 +261,6 @@ where
|
|||
HttpService::with_config(cfg, service.into_factory())
|
||||
.expect(self.expect)
|
||||
.upgrade(self.upgrade)
|
||||
.on_connect(self.on_connect)
|
||||
.on_request(self.on_request)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,10 +42,8 @@ impl ClientBuilder {
|
|||
/// Use custom connector service.
|
||||
pub fn connector<T>(mut self, connector: T) -> Self
|
||||
where
|
||||
T: Service<Request = Connect, Error = ConnectError> + 'static,
|
||||
T::Response: Connection,
|
||||
<T::Response as Connection>::Future: 'static,
|
||||
T::Future: 'static,
|
||||
T: Service<Request = Connect, Response = Connection, Error = ConnectError>
|
||||
+ 'static,
|
||||
{
|
||||
self.config.connector = Box::new(ConnectorWrapper(connector));
|
||||
self
|
||||
|
|
|
@ -1,15 +1,23 @@
|
|||
use std::{fmt, future::Future, io, net, pin::Pin, task::Context, task::Poll};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
|
||||
use crate::http::body::Body;
|
||||
use crate::http::h1::ClientCodec;
|
||||
use crate::http::{RequestHeadType, ResponseHead};
|
||||
use crate::io::IoBoxed;
|
||||
use crate::Service;
|
||||
|
||||
use super::error::{ConnectError, SendRequestError};
|
||||
use super::response::ClientResponse;
|
||||
use super::{Connect as ClientConnect, Connection};
|
||||
|
||||
pub(crate) type TunnelFuture = Pin<
|
||||
Box<
|
||||
dyn Future<
|
||||
Output = Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError>,
|
||||
>,
|
||||
>,
|
||||
>;
|
||||
|
||||
pub(super) struct ConnectorWrapper<T>(pub(crate) T);
|
||||
|
||||
pub(super) trait Connect {
|
||||
|
@ -25,25 +33,12 @@ pub(super) trait Connect {
|
|||
&self,
|
||||
head: RequestHeadType,
|
||||
addr: Option<net::SocketAddr>,
|
||||
) -> Pin<
|
||||
Box<
|
||||
dyn Future<
|
||||
Output = Result<
|
||||
(ResponseHead, Framed<BoxedSocket, ClientCodec>),
|
||||
SendRequestError,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
>;
|
||||
) -> TunnelFuture;
|
||||
}
|
||||
|
||||
impl<T> Connect for ConnectorWrapper<T>
|
||||
where
|
||||
T: Service<Request = ClientConnect, Error = ConnectError>,
|
||||
T::Response: Connection,
|
||||
<T::Response as Connection>::Io: 'static,
|
||||
<T::Response as Connection>::Future: 'static,
|
||||
<T::Response as Connection>::TunnelFuture: 'static,
|
||||
T: Service<Request = ClientConnect, Response = Connection, Error = ConnectError>,
|
||||
T::Future: 'static,
|
||||
{
|
||||
fn send_request(
|
||||
|
@ -73,16 +68,7 @@ where
|
|||
&self,
|
||||
head: RequestHeadType,
|
||||
addr: Option<net::SocketAddr>,
|
||||
) -> Pin<
|
||||
Box<
|
||||
dyn Future<
|
||||
Output = Result<
|
||||
(ResponseHead, Framed<BoxedSocket, ClientCodec>),
|
||||
SendRequestError,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
> {
|
||||
) -> TunnelFuture {
|
||||
// connect to the host
|
||||
let fut = self.0.call(ClientConnect {
|
||||
uri: head.as_ref().uri.clone(),
|
||||
|
@ -93,69 +79,7 @@ where
|
|||
let connection = fut.await?;
|
||||
|
||||
// send request
|
||||
let (head, framed) = connection.open_tunnel(head).await?;
|
||||
|
||||
let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io))));
|
||||
Ok((head, framed))
|
||||
connection.open_tunnel(head).await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
trait AsyncSocket {
|
||||
fn as_read(&self) -> &(dyn AsyncRead + Unpin);
|
||||
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin);
|
||||
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin);
|
||||
}
|
||||
|
||||
struct Socket<T: AsyncRead + AsyncWrite + Unpin>(T);
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncSocket for Socket<T> {
|
||||
fn as_read(&self) -> &(dyn AsyncRead + Unpin) {
|
||||
&self.0
|
||||
}
|
||||
fn as_read_mut(&mut self) -> &mut (dyn AsyncRead + Unpin) {
|
||||
&mut self.0
|
||||
}
|
||||
fn as_write(&mut self) -> &mut (dyn AsyncWrite + Unpin) {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BoxedSocket(Box<dyn AsyncSocket>);
|
||||
|
||||
impl fmt::Debug for BoxedSocket {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "BoxedSocket")
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for BoxedSocket {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for BoxedSocket {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(self.get_mut().0.as_write()).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(self.get_mut().0.as_write()).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(self.get_mut().0.as_write()).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,79 +8,43 @@ use crate::http::h1::ClientCodec;
|
|||
use crate::http::message::{RequestHeadType, ResponseHead};
|
||||
use crate::http::payload::Payload;
|
||||
use crate::http::Protocol;
|
||||
use crate::io::IoBoxed;
|
||||
use crate::util::{Bytes, Either, Ready};
|
||||
|
||||
use super::error::SendRequestError;
|
||||
use super::pool::Acquired;
|
||||
use super::{h1proto, h2proto};
|
||||
|
||||
pub(super) enum ConnectionType<Io> {
|
||||
H1(Io),
|
||||
pub(super) enum ConnectionType {
|
||||
H1(IoBoxed),
|
||||
H2(SendRequest<Bytes>),
|
||||
}
|
||||
|
||||
pub trait Connection {
|
||||
type Io: AsyncRead + AsyncWrite + Unpin;
|
||||
type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>;
|
||||
|
||||
fn protocol(&self) -> Protocol;
|
||||
|
||||
/// Send request and body
|
||||
fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
|
||||
self,
|
||||
head: H,
|
||||
body: B,
|
||||
) -> Self::Future;
|
||||
|
||||
type TunnelFuture: Future<
|
||||
Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
|
||||
>;
|
||||
|
||||
/// Send request, returns Response and Framed
|
||||
fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture;
|
||||
}
|
||||
|
||||
pub(super) trait ConnectionLifetime:
|
||||
AsyncRead + AsyncWrite + Unpin + 'static
|
||||
{
|
||||
/// Close connection
|
||||
fn close(&mut self);
|
||||
|
||||
/// Release connection to the connection pool
|
||||
fn release(&mut self);
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
/// HTTP client connection
|
||||
pub(super) struct IoConnection<T> {
|
||||
io: Option<ConnectionType<T>>,
|
||||
pub struct Connection {
|
||||
io: Option<ConnectionType>,
|
||||
created: time::Instant,
|
||||
pool: Option<Acquired<T>>,
|
||||
pool: Option<Acquired>,
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for IoConnection<T>
|
||||
where
|
||||
T: fmt::Debug,
|
||||
{
|
||||
impl fmt::Debug for Connection {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.io {
|
||||
Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io),
|
||||
Some(ConnectionType::H1(_)) => write!(f, "H1Connection"),
|
||||
Some(ConnectionType::H2(_)) => write!(f, "H2Connection"),
|
||||
None => write!(f, "Connection(Empty)"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IoConnection<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
impl Connection {
|
||||
pub(super) fn new(
|
||||
io: ConnectionType<T>,
|
||||
io: ConnectionType,
|
||||
created: time::Instant,
|
||||
pool: Option<Acquired<T>>,
|
||||
pool: Option<Acquired>,
|
||||
) -> Self {
|
||||
IoConnection {
|
||||
Self {
|
||||
pool,
|
||||
created,
|
||||
io: Some(io),
|
||||
|
@ -97,20 +61,11 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) fn into_inner(self) -> (ConnectionType<T>, time::Instant) {
|
||||
pub(super) fn into_inner(self) -> (ConnectionType, time::Instant) {
|
||||
(self.io.unwrap(), self.created)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Connection for IoConnection<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
type Io = T;
|
||||
type Future =
|
||||
Pin<Box<dyn Future<Output = Result<(ResponseHead, Payload), SendRequestError>>>>;
|
||||
|
||||
fn protocol(&self) -> Protocol {
|
||||
pub fn protocol(&self) -> Protocol {
|
||||
match self.io {
|
||||
Some(ConnectionType::H1(_)) => Protocol::Http1,
|
||||
Some(ConnectionType::H2(_)) => Protocol::Http2,
|
||||
|
@ -118,58 +73,42 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
|
||||
pub(super) async fn send_request<
|
||||
B: MessageBody + 'static,
|
||||
H: Into<RequestHeadType>,
|
||||
>(
|
||||
mut self,
|
||||
head: H,
|
||||
body: B,
|
||||
) -> Self::Future {
|
||||
) -> Result<(ResponseHead, Payload), SendRequestError> {
|
||||
match self.io.take().unwrap() {
|
||||
ConnectionType::H1(io) => Box::pin(h1proto::send_request(
|
||||
io,
|
||||
head.into(),
|
||||
body,
|
||||
self.created,
|
||||
self.pool,
|
||||
)),
|
||||
ConnectionType::H2(io) => Box::pin(h2proto::send_request(
|
||||
io,
|
||||
head.into(),
|
||||
body,
|
||||
self.created,
|
||||
self.pool,
|
||||
)),
|
||||
ConnectionType::H1(io) => {
|
||||
h1proto::send_request(io, head.into(), body, self.created, self.pool)
|
||||
.await
|
||||
}
|
||||
ConnectionType::H2(io) => {
|
||||
h2proto::send_request(io, head.into(), body, self.created, self.pool)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TunnelFuture = Either<
|
||||
Pin<
|
||||
Box<
|
||||
dyn Future<
|
||||
Output = Result<
|
||||
(ResponseHead, Framed<Self::Io, ClientCodec>),
|
||||
SendRequestError,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
Ready<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
|
||||
>;
|
||||
|
||||
/// Send request, returns Response and Framed
|
||||
fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
|
||||
pub(super) async fn open_tunnel<H: Into<RequestHeadType>>(
|
||||
mut self,
|
||||
head: H,
|
||||
) -> Result<(ResponseHead, IoBoxed, ClientCodec), SendRequestError> {
|
||||
match self.io.take().unwrap() {
|
||||
ConnectionType::H1(io) => {
|
||||
Either::Left(Box::pin(h1proto::open_tunnel(io, head.into())))
|
||||
}
|
||||
ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await,
|
||||
ConnectionType::H2(io) => {
|
||||
if let Some(mut pool) = self.pool.take() {
|
||||
pool.release(IoConnection::new(
|
||||
pool.release(Connection::new(
|
||||
ConnectionType::H2(io),
|
||||
self.created,
|
||||
None,
|
||||
));
|
||||
}
|
||||
Either::Right(Ready::Err(SendRequestError::TunnelNotSupported))
|
||||
Err(SendRequestError::TunnelNotSupported)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::{rc::Rc, task::Context, task::Poll, time::Duration};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::connect::{Connect as TcpConnect, Connector as TcpConnector};
|
||||
use crate::http::{Protocol, Uri};
|
||||
use crate::io::{Filter, Io, IoBoxed};
|
||||
use crate::service::{apply_fn, boxed, Service};
|
||||
use crate::time::{Millis, Seconds};
|
||||
use crate::util::timeout::{TimeoutError, TimeoutService};
|
||||
|
@ -16,13 +16,12 @@ use super::Connect;
|
|||
#[cfg(feature = "openssl")]
|
||||
use crate::connect::openssl::SslConnector as OpensslConnector;
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
use crate::connect::rustls::ClientConfig;
|
||||
#[cfg(feature = "rustls")]
|
||||
use std::sync::Arc;
|
||||
//#[cfg(feature = "rustls")]
|
||||
//use crate::connect::rustls::ClientConfig;
|
||||
//#[cfg(feature = "rustls")]
|
||||
//use std::sync::Arc;
|
||||
|
||||
type BoxedConnector =
|
||||
boxed::BoxService<TcpConnect<Uri>, (Box<dyn Io>, Protocol), ConnectError>;
|
||||
type BoxedConnector = boxed::BoxService<TcpConnect<Uri>, IoBoxed, ConnectError>;
|
||||
|
||||
/// Manages http client network connectivity.
|
||||
///
|
||||
|
@ -47,9 +46,6 @@ pub struct Connector {
|
|||
ssl_connector: Option<BoxedConnector>,
|
||||
}
|
||||
|
||||
trait Io: AsyncRead + AsyncWrite + Unpin {}
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
|
||||
|
||||
impl Default for Connector {
|
||||
fn default() -> Self {
|
||||
Connector::new()
|
||||
|
@ -61,7 +57,7 @@ impl Connector {
|
|||
let conn = Connector {
|
||||
connector: boxed::service(
|
||||
TcpConnector::new()
|
||||
.map(|io| (Box::new(io) as Box<dyn Io>, Protocol::Http1))
|
||||
.map(|io| io.into_boxed())
|
||||
.map_err(ConnectError::from),
|
||||
),
|
||||
ssl_connector: None,
|
||||
|
@ -82,28 +78,28 @@ impl Connector {
|
|||
.map_err(|e| error!("Cannot set ALPN protocol: {:?}", e));
|
||||
conn.openssl(ssl.build())
|
||||
}
|
||||
#[cfg(all(not(feature = "openssl"), feature = "rustls"))]
|
||||
{
|
||||
use rust_tls::{OwnedTrustAnchor, RootCertStore};
|
||||
// #[cfg(all(not(feature = "openssl"), feature = "rustls"))]
|
||||
// {
|
||||
// use rust_tls::{OwnedTrustAnchor, RootCertStore};
|
||||
|
||||
let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
let mut cert_store = RootCertStore::empty();
|
||||
cert_store.add_server_trust_anchors(
|
||||
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}),
|
||||
);
|
||||
let mut config = ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(cert_store)
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = protos;
|
||||
conn.rustls(Arc::new(config))
|
||||
}
|
||||
// let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
// let mut cert_store = RootCertStore::empty();
|
||||
// cert_store.add_server_trust_anchors(
|
||||
// webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
// OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
// ta.subject,
|
||||
// ta.spki,
|
||||
// ta.name_constraints,
|
||||
// )
|
||||
// }),
|
||||
// );
|
||||
// let mut config = ClientConfig::builder()
|
||||
// .with_safe_defaults()
|
||||
// .with_root_certificates(cert_store)
|
||||
// .with_no_client_auth();
|
||||
// config.alpn_protocols = protos;
|
||||
// conn.rustls(Arc::new(config))
|
||||
// }
|
||||
#[cfg(not(any(feature = "openssl", feature = "rustls")))]
|
||||
{
|
||||
conn
|
||||
|
@ -126,41 +122,29 @@ impl Connector {
|
|||
pub fn openssl(self, connector: OpensslConnector) -> Self {
|
||||
use crate::connect::openssl::OpensslConnector;
|
||||
|
||||
const H2: &[u8] = b"h2";
|
||||
self.secure_connector(OpensslConnector::new(connector).map(|sock| {
|
||||
let h2 = sock
|
||||
.ssl()
|
||||
.selected_alpn_protocol()
|
||||
.map(|protos| protos.windows(2).any(|w| w == H2))
|
||||
.unwrap_or(false);
|
||||
if h2 {
|
||||
(sock, Protocol::Http2)
|
||||
} else {
|
||||
(sock, Protocol::Http1)
|
||||
}
|
||||
}))
|
||||
self.secure_connector(OpensslConnector::new(connector))
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
/// Use rustls connector for secured connections.
|
||||
pub fn rustls(self, connector: Arc<ClientConfig>) -> Self {
|
||||
use crate::connect::rustls::RustlsConnector;
|
||||
// #[cfg(feature = "rustls")]
|
||||
// /// Use rustls connector for secured connections.
|
||||
// pub fn rustls(self, connector: Arc<ClientConfig>) -> Self {
|
||||
// use crate::connect::rustls::RustlsConnector;
|
||||
|
||||
const H2: &[u8] = b"h2";
|
||||
self.secure_connector(RustlsConnector::new(connector).map(|sock| {
|
||||
let h2 = sock
|
||||
.get_ref()
|
||||
.1
|
||||
.alpn_protocol()
|
||||
.map(|protos| protos.windows(2).any(|w| w == H2))
|
||||
.unwrap_or(false);
|
||||
if h2 {
|
||||
(Box::new(sock) as Box<dyn Io>, Protocol::Http2)
|
||||
} else {
|
||||
(Box::new(sock) as Box<dyn Io>, Protocol::Http1)
|
||||
}
|
||||
}))
|
||||
}
|
||||
// const H2: &[u8] = b"h2";
|
||||
// self.secure_connector(RustlsConnector::new(connector).map(|sock| {
|
||||
// let h2 = sock
|
||||
// .get_ref()
|
||||
// .1
|
||||
// .alpn_protocol()
|
||||
// .map(|protos| protos.windows(2).any(|w| w == H2))
|
||||
// .unwrap_or(false);
|
||||
// if h2 {
|
||||
// (Box::new(sock) as Box<dyn Io>, Protocol::Http2)
|
||||
// } else {
|
||||
// (Box::new(sock) as Box<dyn Io>, Protocol::Http1)
|
||||
// }
|
||||
// }))
|
||||
// }
|
||||
|
||||
/// Set total number of simultaneous connections per type of scheme.
|
||||
///
|
||||
|
@ -206,36 +190,36 @@ impl Connector {
|
|||
}
|
||||
|
||||
/// Use custom connector to open un-secured connections.
|
||||
pub fn connector<T, U>(mut self, connector: T) -> Self
|
||||
pub fn connector<T, U, F>(mut self, connector: T) -> Self
|
||||
where
|
||||
U: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<
|
||||
Request = TcpConnect<Uri>,
|
||||
Response = (U, Protocol),
|
||||
Response = Io<F>,
|
||||
Error = crate::connect::ConnectError,
|
||||
> + 'static,
|
||||
F: Filter,
|
||||
{
|
||||
self.connector = boxed::service(
|
||||
connector
|
||||
.map(|(io, proto)| (Box::new(io) as Box<dyn Io>, proto))
|
||||
.map(|io| io.into_boxed())
|
||||
.map_err(ConnectError::from),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Use custom connector to open secure connections.
|
||||
pub fn secure_connector<T, U>(mut self, connector: T) -> Self
|
||||
pub fn secure_connector<T, F>(mut self, connector: T) -> Self
|
||||
where
|
||||
U: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<
|
||||
Request = TcpConnect<Uri>,
|
||||
Response = (U, Protocol),
|
||||
Response = Io<F>,
|
||||
Error = crate::connect::ConnectError,
|
||||
> + 'static,
|
||||
F: Filter,
|
||||
{
|
||||
self.ssl_connector = Some(boxed::service(
|
||||
connector
|
||||
.map(|(io, proto)| (Box::new(io) as Box<dyn Io>, proto))
|
||||
.map(|io| io.into_boxed())
|
||||
.map_err(ConnectError::from),
|
||||
));
|
||||
self
|
||||
|
@ -246,12 +230,13 @@ impl Connector {
|
|||
/// its combinator chain.
|
||||
pub fn finish(
|
||||
self,
|
||||
) -> impl Service<Request = Connect, Response = impl Connection, Error = ConnectError>
|
||||
+ Clone {
|
||||
let tcp_service = connector(self.connector, self.timeout);
|
||||
) -> impl Service<Request = Connect, Response = Connection, Error = ConnectError> + Clone
|
||||
{
|
||||
let tcp_service =
|
||||
connector(self.connector, self.timeout, self.disconnect_timeout);
|
||||
|
||||
let ssl_pool = if let Some(ssl_connector) = self.ssl_connector {
|
||||
let srv = connector(ssl_connector, self.timeout);
|
||||
let srv = connector(ssl_connector, self.timeout, self.disconnect_timeout);
|
||||
Some(ConnectionPool::new(
|
||||
srv,
|
||||
self.conn_lifetime,
|
||||
|
@ -279,9 +264,10 @@ impl Connector {
|
|||
fn connector(
|
||||
connector: BoxedConnector,
|
||||
timeout: Millis,
|
||||
disconnect_timeout: Millis,
|
||||
) -> impl Service<
|
||||
Request = Connect,
|
||||
Response = (Box<dyn Io>, Protocol),
|
||||
Response = IoBoxed,
|
||||
Error = ConnectError,
|
||||
Future = impl Unpin,
|
||||
> + Unpin {
|
||||
|
@ -290,6 +276,10 @@ fn connector(
|
|||
apply_fn(connector, |msg: Connect, srv| {
|
||||
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
|
||||
})
|
||||
.map(move |io: IoBoxed| {
|
||||
io.set_disconnect_timeout(disconnect_timeout);
|
||||
io
|
||||
})
|
||||
.map_err(ConnectError::from),
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
|
@ -298,28 +288,25 @@ fn connector(
|
|||
})
|
||||
}
|
||||
|
||||
type Pool<T> = ConnectionPool<T, Box<dyn Io>>;
|
||||
|
||||
struct InnerConnector<T> {
|
||||
tcp_pool: Pool<T>,
|
||||
ssl_pool: Option<Pool<T>>,
|
||||
tcp_pool: ConnectionPool<T>,
|
||||
ssl_pool: Option<ConnectionPool<T>>,
|
||||
}
|
||||
|
||||
impl<T> Service for InnerConnector<T>
|
||||
where
|
||||
T: Service<
|
||||
Request = Connect,
|
||||
Response = (Box<dyn Io>, Protocol),
|
||||
Error = ConnectError,
|
||||
> + Unpin
|
||||
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError>
|
||||
+ Unpin
|
||||
+ 'static,
|
||||
T::Future: Unpin,
|
||||
{
|
||||
type Request = Connect;
|
||||
type Response = <Pool<T> as Service>::Response;
|
||||
type Response = <ConnectionPool<T> as Service>::Response;
|
||||
type Error = ConnectError;
|
||||
type Future =
|
||||
Either<<Pool<T> as Service>::Future, Ready<Self::Response, Self::Error>>;
|
||||
type Future = Either<
|
||||
<ConnectionPool<T> as Service>::Future,
|
||||
Ready<Self::Response, Self::Error>,
|
||||
>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
|
|
@ -1,28 +1,27 @@
|
|||
use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time};
|
||||
use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time::Instant};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
|
||||
use crate::http::body::{BodySize, MessageBody};
|
||||
use crate::http::error::PayloadError;
|
||||
use crate::http::h1;
|
||||
use crate::http::header::{HeaderMap, HeaderValue, HOST};
|
||||
use crate::http::message::{RequestHeadType, ResponseHead};
|
||||
use crate::http::payload::{Payload, PayloadStream};
|
||||
use crate::io::IoBoxed;
|
||||
use crate::util::{next, poll_fn, send, BufMut, Bytes, BytesMut};
|
||||
use crate::{Sink, Stream};
|
||||
|
||||
use super::connection::{ConnectionLifetime, ConnectionType, IoConnection};
|
||||
use super::connection::{Connection, ConnectionType};
|
||||
use super::error::{ConnectError, SendRequestError};
|
||||
use super::pool::Acquired;
|
||||
|
||||
pub(super) async fn send_request<T, B>(
|
||||
io: T,
|
||||
pub(super) async fn send_request<B>(
|
||||
io: IoBoxed,
|
||||
mut head: RequestHeadType,
|
||||
body: B,
|
||||
created: time::Instant,
|
||||
pool: Option<Acquired<T>>,
|
||||
created: Instant,
|
||||
pool: Option<Acquired>,
|
||||
) -> Result<(ResponseHead, Payload), SendRequestError>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
B: MessageBody,
|
||||
{
|
||||
// set request host header
|
||||
|
@ -52,209 +51,130 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
let io = H1Connection {
|
||||
created,
|
||||
pool,
|
||||
io: Some(io),
|
||||
};
|
||||
// let io = H1Connection {
|
||||
// created,
|
||||
// pool,
|
||||
// io: Some(io),
|
||||
// };
|
||||
|
||||
// create Framed and send request
|
||||
let mut framed = Framed::new(io, h1::ClientCodec::default());
|
||||
send(&mut framed, (head, body.size()).into()).await?;
|
||||
// send request
|
||||
let codec = h1::ClientCodec::default();
|
||||
io.send((head, body.size()).into(), &codec).await?;
|
||||
|
||||
// send request body
|
||||
match body.size() {
|
||||
BodySize::None | BodySize::Empty | BodySize::Sized(0) => (),
|
||||
_ => send_body(body, &mut framed).await?,
|
||||
_ => {
|
||||
send_body(body, &io, &codec).await?;
|
||||
}
|
||||
};
|
||||
|
||||
// read response and init read body
|
||||
let head = if let Some(result) = next(&mut framed).await {
|
||||
result.map_err(SendRequestError::from)?
|
||||
let head = if let Some(result) = io.next(&codec).await? {
|
||||
result
|
||||
} else {
|
||||
return Err(SendRequestError::from(ConnectError::Disconnected));
|
||||
};
|
||||
|
||||
match framed.get_codec().message_type() {
|
||||
match codec.message_type() {
|
||||
h1::MessageType::None => {
|
||||
let force_close = !framed.get_codec().keepalive();
|
||||
release_connection(framed, force_close);
|
||||
let force_close = !codec.keepalive();
|
||||
release_connection(io, force_close, created, pool);
|
||||
Ok((head, Payload::None))
|
||||
}
|
||||
_ => {
|
||||
let pl: PayloadStream = Box::pin(PlStream::new(framed));
|
||||
let pl: PayloadStream = Box::pin(PlStream::new(io, codec, created, pool));
|
||||
Ok((head, pl.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn open_tunnel<T>(
|
||||
io: T,
|
||||
pub(super) async fn open_tunnel(
|
||||
io: IoBoxed,
|
||||
head: RequestHeadType,
|
||||
) -> Result<(ResponseHead, Framed<T, h1::ClientCodec>), SendRequestError>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
) -> Result<(ResponseHead, IoBoxed, h1::ClientCodec), SendRequestError> {
|
||||
// create Framed and send request
|
||||
let mut framed = Framed::new(io, h1::ClientCodec::default());
|
||||
send(&mut framed, (head, BodySize::None).into()).await?;
|
||||
let codec = h1::ClientCodec::default();
|
||||
io.send((head, BodySize::None).into(), &codec).await?;
|
||||
|
||||
// read response
|
||||
if let Some(result) = next(&mut framed).await {
|
||||
let head = result.map_err(SendRequestError::from)?;
|
||||
Ok((head, framed))
|
||||
if let Some(head) = io.next(&codec).await? {
|
||||
Ok((head, io, codec))
|
||||
} else {
|
||||
Err(SendRequestError::from(ConnectError::Disconnected))
|
||||
}
|
||||
}
|
||||
|
||||
/// send request body to the peer
|
||||
pub(super) async fn send_body<I, B>(
|
||||
pub(super) async fn send_body<B>(
|
||||
mut body: B,
|
||||
framed: &mut Framed<I, h1::ClientCodec>,
|
||||
io: &IoBoxed,
|
||||
codec: &h1::ClientCodec,
|
||||
) -> Result<(), SendRequestError>
|
||||
where
|
||||
I: ConnectionLifetime,
|
||||
B: MessageBody,
|
||||
{
|
||||
let mut eof = false;
|
||||
while !eof {
|
||||
while !eof && !framed.is_write_buf_full() {
|
||||
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
|
||||
Some(result) => {
|
||||
framed.write(h1::Message::Chunk(Some(result?)))?;
|
||||
}
|
||||
None => {
|
||||
eof = true;
|
||||
framed.write(h1::Message::Chunk(None))?;
|
||||
let wrt = io.write();
|
||||
|
||||
loop {
|
||||
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
|
||||
Some(result) => {
|
||||
if !wrt.encode(h1::Message::Chunk(Some(result?)), codec)? {
|
||||
wrt.flush(false).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !framed.is_write_buf_empty() {
|
||||
poll_fn(|cx| match framed.flush(cx) {
|
||||
Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
|
||||
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
|
||||
Poll::Pending => {
|
||||
if !framed.is_write_buf_full() {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
None => {
|
||||
wrt.encode(h1::Message::Chunk(None), codec)?;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
poll_fn(|cx| Pin::new(&mut *framed).poll_flush(cx)).await?;
|
||||
wrt.flush(true).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
/// HTTP client connection
|
||||
pub(super) struct H1Connection<T> {
|
||||
io: Option<T>,
|
||||
created: time::Instant,
|
||||
pool: Option<Acquired<T>>,
|
||||
pub(super) struct PlStream {
|
||||
io: Option<IoBoxed>,
|
||||
codec: h1::ClientPayloadCodec,
|
||||
created: Instant,
|
||||
pool: Option<Acquired>,
|
||||
}
|
||||
|
||||
impl<T> ConnectionLifetime for H1Connection<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
/// Close connection
|
||||
fn close(&mut self) {
|
||||
if let Some(mut pool) = self.pool.take() {
|
||||
if let Some(io) = self.io.take() {
|
||||
pool.close(IoConnection::new(
|
||||
ConnectionType::H1(io),
|
||||
self.created,
|
||||
None,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Release this connection to the connection pool
|
||||
fn release(&mut self) {
|
||||
if let Some(mut pool) = self.pool.take() {
|
||||
if let Some(io) = self.io.take() {
|
||||
pool.release(IoConnection::new(
|
||||
ConnectionType::H1(io),
|
||||
self.created,
|
||||
None,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncWrite for H1Connection<T> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.io.as_mut().unwrap()).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(self.io.as_mut().unwrap()).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), io::Error>> {
|
||||
Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct PlStream<Io> {
|
||||
framed: Option<Framed<Io, h1::ClientPayloadCodec>>,
|
||||
}
|
||||
|
||||
impl<Io: ConnectionLifetime> PlStream<Io> {
|
||||
fn new(framed: Framed<Io, h1::ClientCodec>) -> Self {
|
||||
impl PlStream {
|
||||
fn new(
|
||||
io: IoBoxed,
|
||||
codec: h1::ClientCodec,
|
||||
created: Instant,
|
||||
pool: Option<Acquired>,
|
||||
) -> Self {
|
||||
PlStream {
|
||||
framed: Some(framed.map_codec(|codec| codec.into_payload_codec())),
|
||||
io: Some(io),
|
||||
codec: codec.into_payload_codec(),
|
||||
created,
|
||||
pool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
|
||||
impl Stream for PlStream {
|
||||
type Item = Result<Bytes, PayloadError>;
|
||||
|
||||
fn poll_next(
|
||||
self: Pin<&mut Self>,
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
let mut this = self.as_mut();
|
||||
|
||||
match this.framed.as_mut().unwrap().next_item(cx)? {
|
||||
match this.io.as_ref().unwrap().poll_next(&this.codec, cx)? {
|
||||
Poll::Pending => Poll::Pending,
|
||||
Poll::Ready(Some(chunk)) => {
|
||||
if let Some(chunk) = chunk {
|
||||
Poll::Ready(Some(Ok(chunk)))
|
||||
} else {
|
||||
let framed = this.framed.take().unwrap();
|
||||
let force_close = !framed.get_codec().keepalive();
|
||||
release_connection(framed, force_close);
|
||||
let io = this.io.take().unwrap();
|
||||
let force_close = !this.codec.keepalive();
|
||||
release_connection(io, force_close, this.created, this.pool.take());
|
||||
Poll::Ready(None)
|
||||
}
|
||||
}
|
||||
|
@ -263,14 +183,17 @@ impl<Io: ConnectionLifetime> Stream for PlStream<Io> {
|
|||
}
|
||||
}
|
||||
|
||||
fn release_connection<T, U>(framed: Framed<T, U>, force_close: bool)
|
||||
where
|
||||
T: ConnectionLifetime,
|
||||
{
|
||||
let mut parts = framed.into_parts();
|
||||
if !force_close && parts.read_buf.is_empty() && parts.write_buf.is_empty() {
|
||||
parts.io.release()
|
||||
} else {
|
||||
parts.io.close()
|
||||
fn release_connection(
|
||||
io: IoBoxed,
|
||||
force_close: bool,
|
||||
created: Instant,
|
||||
mut pool: Option<Acquired>,
|
||||
) {
|
||||
if force_close || io.is_closed() || io.read().with_buf(|buf| !buf.is_empty()) {
|
||||
if let Some(mut pool) = pool.take() {
|
||||
pool.close(Connection::new(ConnectionType::H1(io), created, None));
|
||||
}
|
||||
} else if let Some(mut pool) = pool.take() {
|
||||
pool.release(Connection::new(ConnectionType::H1(io), created, None));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,19 +11,18 @@ use crate::http::message::{RequestHeadType, ResponseHead};
|
|||
use crate::http::payload::Payload;
|
||||
use crate::util::{poll_fn, Bytes};
|
||||
|
||||
use super::connection::{ConnectionType, IoConnection};
|
||||
use super::connection::{Connection, ConnectionType};
|
||||
use super::error::SendRequestError;
|
||||
use super::pool::Acquired;
|
||||
|
||||
pub(super) async fn send_request<T, B>(
|
||||
pub(super) async fn send_request<B>(
|
||||
mut io: SendRequest<Bytes>,
|
||||
head: RequestHeadType,
|
||||
body: B,
|
||||
created: time::Instant,
|
||||
pool: Option<Acquired<T>>,
|
||||
pool: Option<Acquired>,
|
||||
) -> Result<(ResponseHead, Payload), SendRequestError>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
B: MessageBody,
|
||||
{
|
||||
trace!("Sending client request: {:?} {:?}", head, body.size());
|
||||
|
@ -161,17 +160,17 @@ async fn send_body<B: MessageBody>(
|
|||
}
|
||||
|
||||
// release SendRequest object
|
||||
fn release<T: AsyncRead + AsyncWrite + Unpin + 'static>(
|
||||
fn release(
|
||||
io: SendRequest<Bytes>,
|
||||
pool: Option<Acquired<T>>,
|
||||
pool: Option<Acquired>,
|
||||
created: time::Instant,
|
||||
close: bool,
|
||||
) {
|
||||
if let Some(mut pool) = pool {
|
||||
if close {
|
||||
pool.close(IoConnection::new(ConnectionType::H2(io), created, None));
|
||||
pool.close(Connection::new(ConnectionType::H2(io), created, None));
|
||||
} else {
|
||||
pool.release(IoConnection::new(ConnectionType::H2(io), created, None));
|
||||
pool.release(Connection::new(ConnectionType::H2(io), created, None));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,7 +34,6 @@ mod test;
|
|||
pub mod ws;
|
||||
|
||||
pub use self::builder::ClientBuilder;
|
||||
pub use self::connect::BoxedSocket;
|
||||
pub use self::connection::Connection;
|
||||
pub use self::connector::Connector;
|
||||
pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder};
|
||||
|
@ -47,7 +46,7 @@ use crate::http::error::HttpError;
|
|||
use crate::http::{HeaderMap, Method, RequestHead, Uri};
|
||||
use crate::time::Millis;
|
||||
|
||||
use self::connect::{Connect as InnerConnect, ConnectorWrapper};
|
||||
use self::connect::{Connect as HttpConnect, ConnectorWrapper};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Connect {
|
||||
|
@ -76,7 +75,7 @@ pub struct Connect {
|
|||
pub struct Client(Rc<ClientConfig>);
|
||||
|
||||
pub(self) struct ClientConfig {
|
||||
pub(self) connector: Box<dyn InnerConnect>,
|
||||
pub(self) connector: Box<dyn HttpConnect>,
|
||||
pub(self) headers: HeaderMap,
|
||||
pub(self) timeout: Millis,
|
||||
}
|
||||
|
|
|
@ -2,19 +2,20 @@ use std::task::{Context, Poll};
|
|||
use std::time::{Duration, Instant};
|
||||
use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc};
|
||||
|
||||
use h2::client::{Builder, Connection, SendRequest};
|
||||
use h2::client::{Builder, Connection as H2Connection, SendRequest};
|
||||
use http::uri::Authority;
|
||||
|
||||
use crate::channel::pool;
|
||||
use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use crate::http::Protocol;
|
||||
use crate::io::IoBoxed;
|
||||
use crate::rt::spawn;
|
||||
use crate::service::Service;
|
||||
use crate::task::LocalWaker;
|
||||
use crate::time::{now, sleep, Millis, Sleep};
|
||||
use crate::util::{poll_fn, Bytes, HashMap};
|
||||
|
||||
use super::connection::{ConnectionType, IoConnection};
|
||||
use super::connection::{Connection, ConnectionType};
|
||||
use super::error::ConnectError;
|
||||
use super::Connect;
|
||||
|
||||
|
@ -29,16 +30,15 @@ impl From<Authority> for Key {
|
|||
}
|
||||
}
|
||||
|
||||
type Waiter<Io> = pool::Sender<Result<IoConnection<Io>, ConnectError>>;
|
||||
type WaiterReceiver<Io> = pool::Receiver<Result<IoConnection<Io>, ConnectError>>;
|
||||
type Waiter = pool::Sender<Result<Connection, ConnectError>>;
|
||||
type WaiterReceiver = pool::Receiver<Result<Connection, ConnectError>>;
|
||||
|
||||
/// Connections pool
|
||||
pub(super) struct ConnectionPool<T, Io: 'static>(Rc<T>, Rc<RefCell<Inner<Io>>>);
|
||||
pub(super) struct ConnectionPool<T>(Rc<T>, Rc<RefCell<Inner>>);
|
||||
|
||||
impl<T, Io> ConnectionPool<T, Io>
|
||||
impl<T> ConnectionPool<T>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
|
||||
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError>
|
||||
+ Unpin
|
||||
+ 'static,
|
||||
T::Future: Unpin,
|
||||
|
@ -73,35 +73,27 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T, Io> Drop for ConnectionPool<T, Io>
|
||||
where
|
||||
Io: 'static,
|
||||
{
|
||||
impl<T> Drop for ConnectionPool<T> {
|
||||
fn drop(&mut self) {
|
||||
self.1.borrow().waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Io> Clone for ConnectionPool<T, Io>
|
||||
where
|
||||
Io: 'static,
|
||||
{
|
||||
impl<T> Clone for ConnectionPool<T> {
|
||||
fn clone(&self) -> Self {
|
||||
ConnectionPool(self.0.clone(), self.1.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Io> Service for ConnectionPool<T, Io>
|
||||
impl<T> Service for ConnectionPool<T>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
|
||||
+ 'static,
|
||||
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError> + 'static,
|
||||
T::Future: Unpin,
|
||||
{
|
||||
type Request = Connect;
|
||||
type Response = IoConnection<Io>;
|
||||
type Response = Connection;
|
||||
type Error = ConnectError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<IoConnection<Io>, ConnectError>>>>;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Connection, ConnectError>>>>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
@ -127,11 +119,12 @@ where
|
|||
};
|
||||
|
||||
// acquire connection
|
||||
match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await {
|
||||
let result = inner.borrow_mut().acquire(&key);
|
||||
match result {
|
||||
// use existing connection
|
||||
Acquire::Acquired(io, created) => {
|
||||
trace!("Use existing connection for {:?}", req.uri);
|
||||
Ok(IoConnection::new(
|
||||
Ok(Connection::new(
|
||||
io,
|
||||
created,
|
||||
Some(Acquired(key, Some(inner))),
|
||||
|
@ -165,31 +158,31 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
enum Acquire<T> {
|
||||
Acquired(ConnectionType<T>, Instant),
|
||||
enum Acquire {
|
||||
Acquired(ConnectionType, Instant),
|
||||
Available,
|
||||
NotAvailable,
|
||||
}
|
||||
|
||||
struct AvailableConnection<Io> {
|
||||
io: ConnectionType<Io>,
|
||||
struct AvailableConnection {
|
||||
io: ConnectionType,
|
||||
used: Instant,
|
||||
created: Instant,
|
||||
}
|
||||
|
||||
pub(super) struct Inner<Io> {
|
||||
pub(super) struct Inner {
|
||||
conn_lifetime: Duration,
|
||||
conn_keep_alive: Duration,
|
||||
disconnect_timeout: Millis,
|
||||
limit: usize,
|
||||
acquired: usize,
|
||||
available: HashMap<Key, VecDeque<AvailableConnection<Io>>>,
|
||||
waiters: VecDeque<(Key, Connect, Waiter<Io>)>,
|
||||
available: HashMap<Key, VecDeque<AvailableConnection>>,
|
||||
waiters: VecDeque<(Key, Connect, Waiter)>,
|
||||
waker: LocalWaker,
|
||||
pool: pool::Pool<Result<IoConnection<Io>, ConnectError>>,
|
||||
pool: pool::Pool<Result<Connection, ConnectError>>,
|
||||
}
|
||||
|
||||
impl<Io> Inner<Io> {
|
||||
impl Inner {
|
||||
fn reserve(&mut self) {
|
||||
self.acquired += 1;
|
||||
}
|
||||
|
@ -199,12 +192,9 @@ impl<Io> Inner<Io> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<Io> Inner<Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
impl Inner {
|
||||
/// connection is not available, wait
|
||||
fn wait_for(&mut self, connect: Connect) -> WaiterReceiver<Io> {
|
||||
fn wait_for(&mut self, connect: Connect) -> WaiterReceiver {
|
||||
let (tx, rx) = self.pool.channel();
|
||||
let key: Key = connect.uri.authority().unwrap().clone().into();
|
||||
self.waiters.push_back((key, connect, tx));
|
||||
|
@ -226,7 +216,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire<Io> {
|
||||
fn acquire(&mut self, key: &Key) -> Acquire {
|
||||
self.cleanup();
|
||||
|
||||
// check limits
|
||||
|
@ -249,19 +239,22 @@ where
|
|||
CloseConnection::spawn(io, self.disconnect_timeout);
|
||||
}
|
||||
} else {
|
||||
let mut io = conn.io;
|
||||
let mut buf = [0; 2];
|
||||
let mut read_buf = ReadBuf::new(&mut buf);
|
||||
if let ConnectionType::H1(ref mut s) = io {
|
||||
match Pin::new(s).poll_read(cx, &mut read_buf) {
|
||||
Poll::Pending => (),
|
||||
Poll::Ready(Ok(_)) if !read_buf.filled().is_empty() => {
|
||||
if let ConnectionType::H1(io) = io {
|
||||
CloseConnection::spawn(io, self.disconnect_timeout);
|
||||
}
|
||||
continue;
|
||||
let io = conn.io;
|
||||
if let ConnectionType::H1(ref s) = io {
|
||||
if s.is_closed() {
|
||||
continue;
|
||||
}
|
||||
let is_valid = s.read().with_buf(|buf| {
|
||||
if buf.is_empty() || (buf.len() == 2 && &buf[..] == b"\r\n")
|
||||
{
|
||||
buf.clear();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
_ => continue,
|
||||
});
|
||||
if !is_valid {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return Acquire::Acquired(io, conn.created);
|
||||
|
@ -271,7 +264,7 @@ where
|
|||
Acquire::Available
|
||||
}
|
||||
|
||||
fn release_conn(&mut self, key: &Key, io: ConnectionType<Io>, created: Instant) {
|
||||
fn release_conn(&mut self, key: &Key, io: ConnectionType, created: Instant) {
|
||||
self.acquired -= 1;
|
||||
self.available
|
||||
.entry(key.clone())
|
||||
|
@ -284,7 +277,7 @@ where
|
|||
self.check_availibility();
|
||||
}
|
||||
|
||||
fn release_close(&mut self, io: ConnectionType<Io>) {
|
||||
fn release_close(&mut self, io: ConnectionType) {
|
||||
self.acquired -= 1;
|
||||
if let ConnectionType::H1(io) = io {
|
||||
CloseConnection::spawn(io, self.disconnect_timeout);
|
||||
|
@ -300,19 +293,14 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
struct ConnectionPoolSupport<T, Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
struct ConnectionPoolSupport<T> {
|
||||
connector: T,
|
||||
inner: Rc<RefCell<Inner<Io>>>,
|
||||
inner: Rc<RefCell<Inner>>,
|
||||
}
|
||||
|
||||
impl<T, Io> Future for ConnectionPoolSupport<T, Io>
|
||||
impl<T> Future for ConnectionPoolSupport<T>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
|
||||
+ Unpin,
|
||||
T: Service<Request = Connect, Response = IoBoxed, Error = ConnectError> + Unpin,
|
||||
T::Future: Unpin + 'static,
|
||||
{
|
||||
type Output = ();
|
||||
|
@ -337,11 +325,11 @@ where
|
|||
};
|
||||
let key = key.clone();
|
||||
|
||||
match inner.acquire(&key, cx) {
|
||||
match inner.acquire(&key) {
|
||||
Acquire::NotAvailable => break,
|
||||
Acquire::Acquired(io, created) => {
|
||||
let (key, _, tx) = inner.waiters.pop_front().unwrap();
|
||||
let _ = tx.send(Ok(IoConnection::new(
|
||||
let _ = tx.send(Ok(Connection::new(
|
||||
io,
|
||||
created,
|
||||
Some(Acquired(key.clone(), Some(this.inner.clone()))),
|
||||
|
@ -363,94 +351,43 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
struct CloseConnection<T> {
|
||||
io: T,
|
||||
struct CloseConnection {
|
||||
io: IoBoxed,
|
||||
timeout: Option<Sleep>,
|
||||
shutdown: bool,
|
||||
}
|
||||
|
||||
impl<T> CloseConnection<T>
|
||||
where
|
||||
T: AsyncWrite + AsyncRead + Unpin + 'static,
|
||||
{
|
||||
fn spawn(io: T, timeout: Millis) {
|
||||
spawn(Self {
|
||||
io,
|
||||
shutdown: false,
|
||||
timeout: timeout.map(sleep),
|
||||
impl CloseConnection {
|
||||
fn spawn(io: IoBoxed, timeout: Millis) {
|
||||
spawn(async move {
|
||||
io.shutdown().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for CloseConnection<T>
|
||||
where
|
||||
T: AsyncWrite + AsyncRead + Unpin,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
|
||||
let mut this = self.as_mut();
|
||||
|
||||
// shutdown WRITE side
|
||||
match Pin::new(&mut this.io).poll_shutdown(cx) {
|
||||
Poll::Ready(Ok(())) => this.shutdown = true,
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(_)) => return Poll::Ready(()),
|
||||
}
|
||||
|
||||
// read until 0 or err
|
||||
if let Some(ref timeout) = this.timeout {
|
||||
match timeout.poll_elapsed(cx) {
|
||||
Poll::Ready(_) => (),
|
||||
Poll::Pending => {
|
||||
let mut buf = [0u8; 512];
|
||||
let mut read_buf = ReadBuf::new(&mut buf);
|
||||
loop {
|
||||
match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(_)) => return Poll::Ready(()),
|
||||
Poll::Ready(Ok(_)) => {
|
||||
if read_buf.filled().is_empty() {
|
||||
return Poll::Ready(());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
||||
|
||||
struct OpenConnection<F, Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
struct OpenConnection<F> {
|
||||
fut: F,
|
||||
h2: Option<
|
||||
Pin<
|
||||
Box<
|
||||
dyn Future<
|
||||
Output = Result<
|
||||
(SendRequest<Bytes>, Connection<Io, Bytes>),
|
||||
(SendRequest<Bytes>, H2Connection<Bytes>),
|
||||
h2::Error,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
>,
|
||||
tx: Option<Waiter<Io>>,
|
||||
guard: Option<OpenGuard<Io>>,
|
||||
tx: Option<Waiter>,
|
||||
guard: Option<OpenGuard>,
|
||||
}
|
||||
|
||||
impl<F, Io> OpenConnection<F, Io>
|
||||
impl<F> OpenConnection<F>
|
||||
where
|
||||
F: Future<Output = Result<(Io, Protocol), ConnectError>> + Unpin + 'static,
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Future<Output = Result<IoBoxed, ConnectError>> + Unpin + 'static,
|
||||
{
|
||||
fn spawn(key: Key, tx: Waiter<Io>, inner: Rc<RefCell<Inner<Io>>>, fut: F) {
|
||||
fn spawn(key: Key, tx: Waiter, inner: Rc<RefCell<Inner>>, fut: F) {
|
||||
spawn(OpenConnection {
|
||||
fut,
|
||||
h2: None,
|
||||
|
@ -463,10 +400,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<F, Io> Future for OpenConnection<F, Io>
|
||||
impl<F> Future for OpenConnection<F>
|
||||
where
|
||||
F: Future<Output = Result<(Io, Protocol), ConnectError>> + Unpin,
|
||||
Io: AsyncRead + AsyncWrite + Unpin,
|
||||
F: Future<Output = Result<IoBoxed, ConnectError>> + Unpin,
|
||||
{
|
||||
type Output = ();
|
||||
|
||||
|
@ -478,7 +414,7 @@ where
|
|||
return match Pin::new(h2).poll(cx) {
|
||||
Poll::Ready(Ok((snd, connection))) => {
|
||||
// h2 connection is ready
|
||||
let conn = IoConnection::new(
|
||||
let conn = Connection::new(
|
||||
ConnectionType::H2(snd),
|
||||
now(),
|
||||
Some(this.guard.take().unwrap().consume()),
|
||||
|
@ -488,7 +424,7 @@ where
|
|||
conn.release()
|
||||
}
|
||||
spawn(async move {
|
||||
let _ = connection.await;
|
||||
// let _ = connection.await;
|
||||
});
|
||||
Poll::Ready(())
|
||||
}
|
||||
|
@ -511,52 +447,43 @@ where
|
|||
}
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Ready(Ok((io, proto))) => {
|
||||
Poll::Ready(Ok(io)) => {
|
||||
trace!("Connection is established");
|
||||
// handle http1 proto
|
||||
if proto == Protocol::Http1 {
|
||||
let conn = IoConnection::new(
|
||||
ConnectionType::H1(io),
|
||||
now(),
|
||||
Some(this.guard.take().unwrap().consume()),
|
||||
);
|
||||
if let Err(Ok(conn)) = this.tx.take().unwrap().send(Ok(conn)) {
|
||||
// waiter is gone, return connection to pool
|
||||
conn.release()
|
||||
}
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
// init http2 handshake
|
||||
this.h2 = Some(Box::pin(Builder::new().handshake(io)));
|
||||
self.poll(cx)
|
||||
//if proto == Protocol::Http1 {
|
||||
let conn = Connection::new(
|
||||
ConnectionType::H1(io),
|
||||
now(),
|
||||
Some(this.guard.take().unwrap().consume()),
|
||||
);
|
||||
if let Err(Ok(conn)) = this.tx.take().unwrap().send(Ok(conn)) {
|
||||
// waiter is gone, return connection to pool
|
||||
conn.release()
|
||||
}
|
||||
Poll::Ready(())
|
||||
// } else {
|
||||
// init http2 handshake
|
||||
// this.h2 = Some(Box::pin(Builder::new().handshake(io)));
|
||||
// self.poll(cx)
|
||||
//}
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct OpenGuard<Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
struct OpenGuard {
|
||||
key: Key,
|
||||
inner: Option<Rc<RefCell<Inner<Io>>>>,
|
||||
inner: Option<Rc<RefCell<Inner>>>,
|
||||
}
|
||||
|
||||
impl<Io> OpenGuard<Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
fn consume(mut self) -> Acquired<Io> {
|
||||
impl OpenGuard {
|
||||
fn consume(mut self) -> Acquired {
|
||||
Acquired(self.key.clone(), self.inner.take())
|
||||
}
|
||||
}
|
||||
|
||||
impl<Io> Drop for OpenGuard<Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
impl Drop for OpenGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Some(i) = self.inner.take() {
|
||||
let mut inner = i.as_ref().borrow_mut();
|
||||
|
@ -566,20 +493,17 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
|
||||
pub(super) struct Acquired(Key, Option<Rc<RefCell<Inner>>>);
|
||||
|
||||
impl<T> Acquired<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
pub(super) fn close(&mut self, conn: IoConnection<T>) {
|
||||
impl Acquired {
|
||||
pub(super) fn close(&mut self, conn: Connection) {
|
||||
if let Some(inner) = self.1.take() {
|
||||
let (io, _) = conn.into_inner();
|
||||
inner.as_ref().borrow_mut().release_close(io);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn release(&mut self, conn: IoConnection<T>) {
|
||||
pub(super) fn release(&mut self, conn: Connection) {
|
||||
if let Some(inner) = self.1.take() {
|
||||
let (io, created) = conn.into_inner();
|
||||
inner
|
||||
|
@ -590,7 +514,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Acquired<T> {
|
||||
impl Drop for Acquired {
|
||||
fn drop(&mut self) {
|
||||
if let Some(inner) = self.1.take() {
|
||||
inner.borrow_mut().release();
|
||||
|
|
|
@ -6,17 +6,16 @@ use coo_kie::{Cookie, CookieJar};
|
|||
use nanorand::{Rng, WyRand};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Framed};
|
||||
use crate::framed::{DispatchItem, Dispatcher, State};
|
||||
use crate::http::error::HttpError;
|
||||
use crate::http::header::{self, HeaderName, HeaderValue, AUTHORIZATION};
|
||||
use crate::http::{ConnectionType, Payload, RequestHead, StatusCode, Uri};
|
||||
use crate::io::{DefaultFilter, DispatchItem, Dispatcher, Filter, Io, IoBoxed};
|
||||
use crate::service::{apply_fn, into_service, IntoService, Service};
|
||||
use crate::util::Either;
|
||||
use crate::{channel::mpsc, rt, time::timeout, util::sink, util::Ready, ws};
|
||||
|
||||
pub use crate::ws::{CloseCode, CloseReason, Frame, Message};
|
||||
|
||||
use super::connect::BoxedSocket;
|
||||
use super::error::{InvalidUrl, SendRequestError, WsClientError};
|
||||
use super::response::ClientResponse;
|
||||
use super::ClientConfig;
|
||||
|
@ -311,7 +310,7 @@ impl WsRequest {
|
|||
let fut = self.config.connector.open_tunnel(head.into(), self.addr);
|
||||
|
||||
// set request timeout
|
||||
let (head, framed) = if self.config.timeout.non_zero() {
|
||||
let (head, io, codec) = if self.config.timeout.non_zero() {
|
||||
timeout(self.config.timeout, fut)
|
||||
.await
|
||||
.map_err(|_| SendRequestError::Timeout)
|
||||
|
@ -377,13 +376,12 @@ impl WsRequest {
|
|||
// response and ws io
|
||||
Ok(WsConnection::new(
|
||||
ClientResponse::new(head, Payload::None),
|
||||
framed.map_codec(|_| {
|
||||
if server_mode {
|
||||
ws::Codec::new().max_size(max_size)
|
||||
} else {
|
||||
ws::Codec::new().max_size(max_size).client_mode()
|
||||
}
|
||||
}),
|
||||
io,
|
||||
if server_mode {
|
||||
ws::Codec::new().max_size(max_size)
|
||||
} else {
|
||||
ws::Codec::new().max_size(max_size).client_mode()
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
@ -403,31 +401,20 @@ impl fmt::Debug for WsRequest {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct WsConnection<Io = BoxedSocket> {
|
||||
io: Io,
|
||||
state: State,
|
||||
pub struct WsConnection {
|
||||
io: IoBoxed,
|
||||
codec: ws::Codec,
|
||||
res: ClientResponse,
|
||||
}
|
||||
|
||||
impl<Io> WsConnection<Io>
|
||||
where
|
||||
Io: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
fn new(res: ClientResponse, framed: Framed<Io, ws::Codec>) -> Self {
|
||||
let (io, codec, state) = State::from_framed(framed);
|
||||
|
||||
Self {
|
||||
io,
|
||||
state,
|
||||
codec,
|
||||
res,
|
||||
}
|
||||
impl WsConnection {
|
||||
fn new(res: ClientResponse, io: IoBoxed, codec: ws::Codec) -> Self {
|
||||
Self { io, codec, res }
|
||||
}
|
||||
|
||||
/// Get ws sink
|
||||
pub fn sink(&self) -> ws::WsSink {
|
||||
ws::WsSink::new(self.state.clone(), self.codec.clone())
|
||||
ws::WsSink::new(self.io.get_ref(), self.codec.clone())
|
||||
}
|
||||
|
||||
/// Get reference to response
|
||||
|
@ -458,10 +445,10 @@ where
|
|||
}
|
||||
|
||||
/// Start client websockets service.
|
||||
pub async fn start<T, F>(self, service: F) -> Result<(), ws::WsError<T::Error>>
|
||||
pub async fn start<T, U>(self, service: U) -> Result<(), ws::WsError<T::Error>>
|
||||
where
|
||||
T: Service<Request = ws::Frame, Response = Option<ws::Message>> + 'static,
|
||||
F: IntoService<T>,
|
||||
U: IntoService<T>,
|
||||
{
|
||||
let service = apply_fn(
|
||||
service.into_service().map_err(ws::WsError::Service),
|
||||
|
@ -475,21 +462,22 @@ where
|
|||
DispatchItem::DecoderError(e) | DispatchItem::EncoderError(e) => {
|
||||
Either::Right(Ready::Err(ws::WsError::Protocol(e)))
|
||||
}
|
||||
DispatchItem::IoError(e) => {
|
||||
DispatchItem::Disconnect(Some(e)) => {
|
||||
Either::Right(Ready::Err(ws::WsError::Io(e)))
|
||||
}
|
||||
DispatchItem::Disconnect(None) => {
|
||||
Either::Right(Ready::Err(ws::WsError::Disconnected))
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
Dispatcher::new(self.io, self.codec, self.state, service, Default::default())
|
||||
.await
|
||||
Dispatcher::new(self.io, self.codec, service, Default::default()).await
|
||||
}
|
||||
|
||||
/// Consumes the `WsConnection`, returning it'as underlying I/O framed object
|
||||
/// and response.
|
||||
pub fn into_inner(self) -> (ClientResponse, Framed<Io, ws::Codec>) {
|
||||
let framed = self.state.into_framed(self.io, self.codec);
|
||||
(self.res, framed)
|
||||
pub fn into_inner(self) -> (ClientResponse, IoBoxed, ws::Codec) {
|
||||
(self.res, self.io, self.codec)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::{cell::Cell, cell::RefCell, ptr::copy_nonoverlapping, rc::Rc, time};
|
||||
use std::{cell::Cell, ptr::copy_nonoverlapping, rc::Rc, time};
|
||||
|
||||
use crate::framed::Timer;
|
||||
use crate::http::{Request, Response};
|
||||
use crate::io::{IoRef, Timer};
|
||||
use crate::service::boxed::BoxService;
|
||||
use crate::time::{sleep, Millis, Seconds, Sleep};
|
||||
use crate::util::{BytesMut, PoolId};
|
||||
|
@ -94,9 +94,9 @@ impl ServiceConfig {
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) type OnRequest<T> = BoxService<(Request, Rc<RefCell<T>>), Request, Response>;
|
||||
pub(super) type OnRequest = BoxService<(Request, IoRef), Request, Response>;
|
||||
|
||||
pub(super) struct DispatcherConfig<T, S, X, U> {
|
||||
pub(super) struct DispatcherConfig<S, X, U> {
|
||||
pub(super) service: S,
|
||||
pub(super) expect: X,
|
||||
pub(super) upgrade: Option<U>,
|
||||
|
@ -107,16 +107,16 @@ pub(super) struct DispatcherConfig<T, S, X, U> {
|
|||
pub(super) timer: DateService,
|
||||
pub(super) timer_h1: Timer,
|
||||
pub(super) pool: PoolId,
|
||||
pub(super) on_request: Option<OnRequest<T>>,
|
||||
pub(super) on_request: Option<OnRequest>,
|
||||
}
|
||||
|
||||
impl<T, S, X, U> DispatcherConfig<T, S, X, U> {
|
||||
impl<S, X, U> DispatcherConfig<S, X, U> {
|
||||
pub(super) fn new(
|
||||
cfg: ServiceConfig,
|
||||
service: S,
|
||||
expect: X,
|
||||
upgrade: Option<U>,
|
||||
on_request: Option<OnRequest<T>>,
|
||||
on_request: Option<OnRequest>,
|
||||
) -> Self {
|
||||
DispatcherConfig {
|
||||
service,
|
||||
|
@ -148,10 +148,6 @@ impl<T, S, X, U> DispatcherConfig<T, S, X, U> {
|
|||
self.keep_alive
|
||||
.map(|t| self.timer.now() + time::Duration::from(t))
|
||||
}
|
||||
|
||||
pub(super) fn now(&self) -> time::Instant {
|
||||
self.timer.now()
|
||||
}
|
||||
}
|
||||
|
||||
const DATE_VALUE_LENGTH_HDR: usize = 39;
|
||||
|
|
|
@ -322,7 +322,7 @@ impl MessageType for ResponseHead {
|
|||
Err(ParseError::TooLarge)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
//! Framed transport dispatcher
|
||||
use std::task::{Context, Poll};
|
||||
use std::{
|
||||
cell::RefCell, error::Error, fmt, future::Future, marker, net, pin::Pin, rc::Rc,
|
||||
time,
|
||||
};
|
||||
use std::{error::Error, fmt, future::Future, marker, net, pin::Pin, rc::Rc, time};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::framed::{ReadTask, State as IoState, WriteTask};
|
||||
use crate::io::{Filter, Io, IoRef};
|
||||
use crate::service::Service;
|
||||
use crate::util::Bytes;
|
||||
use crate::{time::now, util::Bytes, util::Either};
|
||||
|
||||
use crate::http;
|
||||
use crate::http::body::{BodySize, MessageBody, ResponseBody};
|
||||
|
@ -37,19 +33,24 @@ bitflags::bitflags! {
|
|||
|
||||
pin_project_lite::pin_project! {
|
||||
/// Dispatcher for HTTP/1.1 protocol
|
||||
pub struct Dispatcher<T, S: Service, B, X: Service, U: Service> {
|
||||
pub struct Dispatcher<F, S: Service, B, X: Service, U: Service> {
|
||||
#[pin]
|
||||
call: CallState<S, X, U>,
|
||||
st: State<B>,
|
||||
inner: DispatcherInner<T, S, B, X, U>,
|
||||
inner: DispatcherInner<F, S, B, X, U>,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(derive_more::Display)]
|
||||
enum State<B> {
|
||||
Call,
|
||||
ReadRequest,
|
||||
ReadPayload,
|
||||
SendPayload { body: ResponseBody<B> },
|
||||
#[display(fmt = "State::SendPayload")]
|
||||
SendPayload {
|
||||
body: ResponseBody<B>,
|
||||
},
|
||||
#[display(fmt = "State::Upgrade")]
|
||||
Upgrade(Option<Request>),
|
||||
Stop,
|
||||
}
|
||||
|
@ -65,17 +66,15 @@ pin_project_lite::pin_project! {
|
|||
}
|
||||
}
|
||||
|
||||
struct DispatcherInner<T, S, B, X, U> {
|
||||
io: Option<Rc<RefCell<T>>>,
|
||||
struct DispatcherInner<F, S, B, X, U> {
|
||||
io: Option<Io<F>>,
|
||||
flags: Flags,
|
||||
codec: Codec,
|
||||
config: Rc<DispatcherConfig<T, S, X, U>>,
|
||||
state: IoState,
|
||||
state: IoRef,
|
||||
config: Rc<DispatcherConfig<S, X, U>>,
|
||||
expire: time::Instant,
|
||||
error: Option<DispatchError>,
|
||||
payload: Option<(PayloadDecoder, PayloadSender)>,
|
||||
peer_addr: Option<net::SocketAddr>,
|
||||
on_connect_data: Option<Box<dyn DataFactory>>,
|
||||
_t: marker::PhantomData<(S, B)>,
|
||||
}
|
||||
|
||||
|
@ -93,42 +92,33 @@ enum WritePayloadStatus<B> {
|
|||
Continue,
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> Dispatcher<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Filter + 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError,
|
||||
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U: Service<Request = (Request, Io<F>, Codec), Response = ()>,
|
||||
U::Error: Error + fmt::Display,
|
||||
{
|
||||
/// Construct new `Dispatcher` instance with outgoing messages stream.
|
||||
pub(in crate::http) fn new(
|
||||
io: T,
|
||||
config: Rc<DispatcherConfig<T, S, X, U>>,
|
||||
peer_addr: Option<net::SocketAddr>,
|
||||
on_connect_data: Option<Box<dyn DataFactory>>,
|
||||
io: Io<F>,
|
||||
config: Rc<DispatcherConfig<S, X, U>>,
|
||||
) -> Self {
|
||||
let mut expire = now();
|
||||
let state = io.get_ref();
|
||||
let codec = Codec::new(config.timer.clone(), config.keep_alive_enabled());
|
||||
let state = IoState::with_memory_pool(config.pool.into());
|
||||
state.set_disconnect_timeout(config.client_disconnect);
|
||||
|
||||
let mut expire = config.timer_h1.now();
|
||||
let io = Rc::new(RefCell::new(io));
|
||||
|
||||
// slow-request timer
|
||||
if config.client_timeout.non_zero() {
|
||||
expire += std::time::Duration::from(config.client_timeout);
|
||||
expire += time::Duration::from(config.client_timeout);
|
||||
config.timer_h1.register(expire, expire, &state);
|
||||
}
|
||||
|
||||
// start support io tasks
|
||||
crate::rt::spawn(ReadTask::new(io.clone(), state.clone()));
|
||||
crate::rt::spawn(WriteTask::new(io.clone(), state.clone()));
|
||||
|
||||
Dispatcher {
|
||||
call: CallState::None,
|
||||
st: State::ReadRequest,
|
||||
|
@ -138,27 +128,25 @@ where
|
|||
error: None,
|
||||
payload: None,
|
||||
codec,
|
||||
config,
|
||||
state,
|
||||
config,
|
||||
expire,
|
||||
peer_addr,
|
||||
on_connect_data,
|
||||
_t: marker::PhantomData,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> Future for Dispatcher<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Filter,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U: Service<Request = (Request, Io<F>, Codec), Response = ()>,
|
||||
U::Error: Error + fmt::Display + 'static,
|
||||
{
|
||||
type Output = Result<(), DispatchError>;
|
||||
|
@ -204,7 +192,6 @@ where
|
|||
)
|
||||
});
|
||||
if this.inner.flags.contains(Flags::UPGRADE) {
|
||||
this.inner.state.stop_io(cx.waker());
|
||||
*this.st = State::Upgrade(Some(req));
|
||||
return Poll::Pending;
|
||||
} else {
|
||||
|
@ -292,7 +279,6 @@ where
|
|||
log::trace!("keep-alive timeout, close connection");
|
||||
}
|
||||
*this.st = State::Stop;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -307,7 +293,7 @@ where
|
|||
req,
|
||||
pl
|
||||
);
|
||||
req.head_mut().peer_addr = this.inner.peer_addr;
|
||||
req.head_mut().io = Some(this.inner.state.clone());
|
||||
|
||||
// configure request payload
|
||||
let upgrade = match pl {
|
||||
|
@ -340,16 +326,9 @@ where
|
|||
);
|
||||
}
|
||||
|
||||
// set on_connect data
|
||||
if let Some(ref on_connect) = this.inner.on_connect_data
|
||||
{
|
||||
on_connect.set(&mut req.extensions_mut());
|
||||
}
|
||||
|
||||
if upgrade {
|
||||
// Handle UPGRADE request
|
||||
log::trace!("prep io for upgrade handler");
|
||||
this.inner.state.stop_io(cx.waker());
|
||||
*this.st = State::Upgrade(Some(req));
|
||||
return Poll::Pending;
|
||||
} else {
|
||||
|
@ -361,11 +340,7 @@ where
|
|||
CallState::Filter {
|
||||
fut: f.call((
|
||||
req,
|
||||
this.inner
|
||||
.io
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
this.inner.state.clone(),
|
||||
)),
|
||||
}
|
||||
} else if req.head().expect() {
|
||||
|
@ -390,13 +365,13 @@ where
|
|||
if this.inner.flags.contains(Flags::STARTED)
|
||||
&& (!this.inner.flags.contains(Flags::KEEPALIVE)
|
||||
|| !this.inner.codec.keepalive_enabled()
|
||||
|| this.inner.state.is_io_err())
|
||||
|| !this.inner.state.is_io_open())
|
||||
{
|
||||
*this.st = State::Stop;
|
||||
this.inner.state.dispatcher_stopped();
|
||||
this.inner.state.stop_dispatcher();
|
||||
continue;
|
||||
}
|
||||
this.inner.state.read().wake(cx.waker());
|
||||
let _ = read.poll_ready(cx);
|
||||
return Poll::Pending;
|
||||
}
|
||||
Err(err) => {
|
||||
|
@ -418,13 +393,13 @@ where
|
|||
*this.st = State::Stop;
|
||||
continue;
|
||||
}
|
||||
this.inner.state.register_dispatcher(cx.waker());
|
||||
let _ = read.poll_ready(cx);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
// consume request's payload
|
||||
State::ReadPayload => {
|
||||
if this.inner.state.is_io_err() {
|
||||
if !this.inner.state.is_io_open() {
|
||||
*this.st = State::Stop;
|
||||
} else {
|
||||
loop {
|
||||
|
@ -445,7 +420,7 @@ where
|
|||
}
|
||||
// send response body
|
||||
State::SendPayload { ref mut body } => {
|
||||
if this.inner.state.is_io_err() {
|
||||
if !this.inner.state.is_io_open() {
|
||||
*this.st = State::Stop;
|
||||
} else {
|
||||
this.inner.poll_read_payload(cx);
|
||||
|
@ -459,7 +434,7 @@ where
|
|||
this.inner
|
||||
.state
|
||||
.write()
|
||||
.enable_backpressure(Some(cx.waker()));
|
||||
.enable_backpressure(Some(cx));
|
||||
return Poll::Pending;
|
||||
}
|
||||
WritePayloadStatus::Continue => (),
|
||||
|
@ -470,50 +445,48 @@ where
|
|||
}
|
||||
// stop io tasks and call upgrade service
|
||||
State::Upgrade(ref mut req) => {
|
||||
// check if all io tasks have been stopped
|
||||
let io = if Rc::strong_count(this.inner.io.as_ref().unwrap()) == 1 {
|
||||
if let Ok(io) = Rc::try_unwrap(this.inner.io.take().unwrap()) {
|
||||
io.into_inner()
|
||||
} else {
|
||||
return Poll::Ready(Err(DispatchError::InternalError));
|
||||
}
|
||||
} else {
|
||||
// wait next task stop
|
||||
this.inner.state.register_dispatcher(cx.waker());
|
||||
return Poll::Pending;
|
||||
};
|
||||
log::trace!("initate upgrade handling");
|
||||
|
||||
let io = this.inner.io.take().unwrap();
|
||||
let req = req.take().unwrap();
|
||||
*this.st = State::Call;
|
||||
this.inner.state.reset_io_stop();
|
||||
|
||||
// Handle UPGRADE request
|
||||
this.call.set(CallState::Upgrade {
|
||||
fut: this.inner.config.upgrade.as_ref().unwrap().call((
|
||||
req,
|
||||
io,
|
||||
this.inner.state.clone(),
|
||||
this.inner.codec.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
// prepare to shutdown
|
||||
State::Stop => {
|
||||
this.inner.state.shutdown_io();
|
||||
this.inner.unregister_keepalive();
|
||||
if this
|
||||
.inner
|
||||
.io
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.poll_shutdown(cx)?
|
||||
.is_ready()
|
||||
{
|
||||
// get io error
|
||||
if this.inner.error.is_none() {
|
||||
this.inner.error =
|
||||
this.inner.state.take_error().map(DispatchError::Io);
|
||||
}
|
||||
|
||||
// get io error
|
||||
if this.inner.error.is_none() {
|
||||
this.inner.error =
|
||||
this.inner.state.take_io_error().map(DispatchError::Io);
|
||||
}
|
||||
|
||||
return Poll::Ready(if let Some(err) = this.inner.error.take() {
|
||||
Err(err)
|
||||
return Poll::Ready(
|
||||
if let Some(err) = this.inner.error.take() {
|
||||
Err(err)
|
||||
} else {
|
||||
Ok(())
|
||||
},
|
||||
);
|
||||
} else {
|
||||
Ok(())
|
||||
});
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -536,8 +509,7 @@ where
|
|||
fn reset_keepalive(&mut self) {
|
||||
// re-register keep-alive
|
||||
if self.flags.contains(Flags::KEEPALIVE) && self.config.keep_alive.non_zero() {
|
||||
let expire = self.config.timer_h1.now()
|
||||
+ std::time::Duration::from(self.config.keep_alive);
|
||||
let expire = now() + time::Duration::from(self.config.keep_alive);
|
||||
if expire != self.expire {
|
||||
self.config
|
||||
.timer_h1
|
||||
|
@ -571,11 +543,11 @@ where
|
|||
}
|
||||
|
||||
fn send_response(&mut self, msg: Response<()>, body: ResponseBody<B>) -> State<B> {
|
||||
trace!("Sending response: {:?} body: {:?}", msg, body.size());
|
||||
trace!("sending response: {:?} body: {:?}", msg, body.size());
|
||||
// we dont need to process responses if socket is disconnected
|
||||
// but we still want to handle requests with app service
|
||||
// so we skip response processing for droppped connection
|
||||
if !self.state.is_io_err() {
|
||||
if self.state.is_io_open() {
|
||||
let result = self
|
||||
.state
|
||||
.write()
|
||||
|
@ -617,7 +589,7 @@ where
|
|||
) -> WritePayloadStatus<B> {
|
||||
match item {
|
||||
Some(Ok(item)) => {
|
||||
trace!("Got response chunk: {:?}", item.len());
|
||||
trace!("got response chunk: {:?}", item.len());
|
||||
match self
|
||||
.state
|
||||
.write()
|
||||
|
@ -637,7 +609,7 @@ where
|
|||
}
|
||||
}
|
||||
None => {
|
||||
trace!("Response payload eof");
|
||||
trace!("response payload eof");
|
||||
if let Err(err) =
|
||||
self.state.write().encode(Message::Chunk(None), &self.codec)
|
||||
{
|
||||
|
@ -653,7 +625,7 @@ where
|
|||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
trace!("Error during response body poll: {:?}", e);
|
||||
trace!("error during response body poll: {:?}", e);
|
||||
self.error = Some(DispatchError::ResponsePayload(e));
|
||||
WritePayloadStatus::Next(State::Stop)
|
||||
}
|
||||
|
@ -686,13 +658,13 @@ where
|
|||
break;
|
||||
}
|
||||
Ok(None) => {
|
||||
if self.state.is_io_err() {
|
||||
if !self.state.is_io_open() {
|
||||
payload.1.set_error(PayloadError::EncodingCorrupted);
|
||||
self.payload = None;
|
||||
self.error = Some(ParseError::Incomplete.into());
|
||||
return ReadPayloadStatus::Dropped;
|
||||
} else {
|
||||
read.wake(cx.waker());
|
||||
let _ = read.poll_ready(cx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -737,6 +709,7 @@ mod tests {
|
|||
use crate::http::config::{DispatcherConfig, ServiceConfig};
|
||||
use crate::http::h1::{ClientCodec, ExpectHandler, UpgradeHandler};
|
||||
use crate::http::{body, Request, ResponseHead, StatusCode};
|
||||
use crate::io::{self as nio, DefaultFilter};
|
||||
use crate::service::{boxed, fn_service, IntoService};
|
||||
use crate::util::{lazy, next, Bytes, BytesMut};
|
||||
use crate::{codec::Decoder, testing::Io, time::sleep, time::Millis};
|
||||
|
@ -747,7 +720,7 @@ mod tests {
|
|||
pub(crate) fn h1<F, S, B>(
|
||||
stream: Io,
|
||||
service: F,
|
||||
) -> Dispatcher<Io, S, B, ExpectHandler, UpgradeHandler<Io>>
|
||||
) -> Dispatcher<DefaultFilter, S, B, ExpectHandler, UpgradeHandler<DefaultFilter>>
|
||||
where
|
||||
F: IntoService<S>,
|
||||
S: Service<Request = Request>,
|
||||
|
@ -756,7 +729,7 @@ mod tests {
|
|||
B: MessageBody,
|
||||
{
|
||||
Dispatcher::new(
|
||||
stream,
|
||||
nio::Io::new(stream),
|
||||
Rc::new(DispatcherConfig::new(
|
||||
ServiceConfig::default(),
|
||||
service.into_service(),
|
||||
|
@ -764,8 +737,6 @@ mod tests {
|
|||
None,
|
||||
None,
|
||||
)),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -777,20 +748,22 @@ mod tests {
|
|||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody + 'static,
|
||||
{
|
||||
crate::rt::spawn(
|
||||
Dispatcher::<Io, S, B, ExpectHandler, UpgradeHandler<Io>>::new(
|
||||
stream,
|
||||
Rc::new(DispatcherConfig::new(
|
||||
ServiceConfig::default(),
|
||||
service.into_service(),
|
||||
ExpectHandler,
|
||||
None,
|
||||
None,
|
||||
)),
|
||||
crate::rt::spawn(Dispatcher::<
|
||||
DefaultFilter,
|
||||
S,
|
||||
B,
|
||||
ExpectHandler,
|
||||
UpgradeHandler<DefaultFilter>,
|
||||
>::new(
|
||||
nio::Io::new(stream),
|
||||
Rc::new(DispatcherConfig::new(
|
||||
ServiceConfig::default(),
|
||||
service.into_service(),
|
||||
ExpectHandler,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
);
|
||||
)),
|
||||
));
|
||||
}
|
||||
|
||||
fn load(decoder: &mut ClientCodec, buf: &mut BytesMut) -> ResponseHead {
|
||||
|
@ -806,7 +779,7 @@ mod tests {
|
|||
let data = Rc::new(Cell::new(false));
|
||||
let data2 = data.clone();
|
||||
let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<Io>>::new(
|
||||
server,
|
||||
nio::Io::new(server),
|
||||
Rc::new(DispatcherConfig::new(
|
||||
ServiceConfig::default(),
|
||||
fn_service(|_| {
|
||||
|
@ -821,8 +794,6 @@ mod tests {
|
|||
},
|
||||
))),
|
||||
)),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
sleep(Millis(50)).await;
|
||||
|
||||
|
|
|
@ -3,35 +3,33 @@ use std::{
|
|||
task,
|
||||
};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::framed::State as IoState;
|
||||
use crate::http::body::MessageBody;
|
||||
use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig};
|
||||
use crate::http::error::{DispatchError, ResponseError};
|
||||
use crate::http::helpers::DataFactory;
|
||||
use crate::http::request::Request;
|
||||
use crate::http::response::Response;
|
||||
use crate::io::{DefaultFilter, Filter, Io, IoRef};
|
||||
use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
|
||||
use crate::{rt::net::TcpStream, time::Millis, util::Pool};
|
||||
use crate::{time::Millis, util::Pool};
|
||||
|
||||
use super::codec::Codec;
|
||||
use super::dispatcher::Dispatcher;
|
||||
use super::{ExpectHandler, UpgradeHandler};
|
||||
|
||||
/// `ServiceFactory` implementation for HTTP1 transport
|
||||
pub struct H1Service<T, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
|
||||
pub struct H1Service<F, S, B, X = ExpectHandler, U = UpgradeHandler<F>> {
|
||||
srv: S,
|
||||
cfg: ServiceConfig,
|
||||
expect: X,
|
||||
upgrade: Option<U>,
|
||||
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
|
||||
on_request: RefCell<Option<OnRequest<T>>>,
|
||||
on_request: RefCell<Option<OnRequest>>,
|
||||
#[allow(dead_code)]
|
||||
handshake_timeout: Millis,
|
||||
_t: marker::PhantomData<(T, B)>,
|
||||
_t: marker::PhantomData<(F, B)>,
|
||||
}
|
||||
|
||||
impl<T, S, B> H1Service<T, S, B>
|
||||
impl<F, S, B> H1Service<F, S, B>
|
||||
where
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
|
@ -40,15 +38,14 @@ where
|
|||
B: MessageBody,
|
||||
{
|
||||
/// Create new `HttpService` instance with config.
|
||||
pub(crate) fn with_config<F: IntoServiceFactory<S>>(
|
||||
pub(crate) fn with_config<U: IntoServiceFactory<S>>(
|
||||
cfg: ServiceConfig,
|
||||
service: F,
|
||||
service: U,
|
||||
) -> Self {
|
||||
H1Service {
|
||||
srv: service.into_factory(),
|
||||
expect: ExpectHandler,
|
||||
upgrade: None,
|
||||
on_connect: None,
|
||||
on_request: RefCell::new(None),
|
||||
handshake_timeout: cfg.0.ssl_handshake_timeout,
|
||||
_t: marker::PhantomData,
|
||||
|
@ -57,53 +54,14 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<S, B, X, U> H1Service<TcpStream, S, B, X, U>
|
||||
where
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
S::Response: Into<Response<B>>,
|
||||
S::Future: 'static,
|
||||
B: MessageBody,
|
||||
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, TcpStream, IoState, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
{
|
||||
/// Create simple tcp stream service
|
||||
pub fn tcp(
|
||||
self,
|
||||
) -> impl ServiceFactory<
|
||||
Config = (),
|
||||
Request = TcpStream,
|
||||
Response = (),
|
||||
Error = DispatchError,
|
||||
InitError = (),
|
||||
> {
|
||||
pipeline_factory(|io: TcpStream| async move {
|
||||
let peer_addr = io.peer_addr().ok();
|
||||
Ok((io, peer_addr))
|
||||
})
|
||||
.and_then(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "openssl")]
|
||||
mod openssl {
|
||||
use super::*;
|
||||
|
||||
use crate::server::openssl::{Acceptor, SslAcceptor, SslStream};
|
||||
use crate::server::openssl::{Acceptor, SslAcceptor, SslFilter};
|
||||
use crate::server::SslError;
|
||||
|
||||
impl<S, B, X, U> H1Service<SslStream<TcpStream>, S, B, X, U>
|
||||
impl<S, B, X, U> H1Service<SslFilter<DefaultFilter>, S, B, X, U>
|
||||
where
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
|
@ -117,7 +75,7 @@ mod openssl {
|
|||
X::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, SslStream<TcpStream>, IoState, Codec),
|
||||
Request = (Request, Io<SslFilter<DefaultFilter>>, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
|
@ -130,7 +88,7 @@ mod openssl {
|
|||
acceptor: SslAcceptor,
|
||||
) -> impl ServiceFactory<
|
||||
Config = (),
|
||||
Request = TcpStream,
|
||||
Request = Io,
|
||||
Response = (),
|
||||
Error = SslError<DispatchError>,
|
||||
InitError = (),
|
||||
|
@ -141,71 +99,68 @@ mod openssl {
|
|||
.map_err(SslError::Ssl)
|
||||
.map_init_err(|_| panic!()),
|
||||
)
|
||||
.and_then(|io: SslStream<TcpStream>| async move {
|
||||
let peer_addr = io.get_ref().peer_addr().ok();
|
||||
Ok((io, peer_addr))
|
||||
})
|
||||
.and_then(self.map_err(SslError::Service))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
mod rustls {
|
||||
use super::*;
|
||||
use crate::server::rustls::{Acceptor, ServerConfig, TlsStream};
|
||||
use crate::server::SslError;
|
||||
use std::fmt;
|
||||
// #[cfg(feature = "rustls")]
|
||||
// mod rustls {
|
||||
// use super::*;
|
||||
// use crate::server::rustls::{Acceptor, ServerConfig, TlsStream};
|
||||
// use crate::server::SslError;
|
||||
// use std::fmt;
|
||||
|
||||
impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
|
||||
where
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
S::Response: Into<Response<B>>,
|
||||
S::Future: 'static,
|
||||
B: MessageBody,
|
||||
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, TlsStream<TcpStream>, IoState, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
{
|
||||
/// Create rustls based service
|
||||
pub fn rustls(
|
||||
self,
|
||||
config: ServerConfig,
|
||||
) -> impl ServiceFactory<
|
||||
Config = (),
|
||||
Request = TcpStream,
|
||||
Response = (),
|
||||
Error = SslError<DispatchError>,
|
||||
InitError = (),
|
||||
> {
|
||||
pipeline_factory(
|
||||
Acceptor::new(config)
|
||||
.timeout(self.handshake_timeout)
|
||||
.map_err(SslError::Ssl)
|
||||
.map_init_err(|_| panic!()),
|
||||
)
|
||||
.and_then(|io: TlsStream<TcpStream>| async move {
|
||||
let peer_addr = io.get_ref().0.peer_addr().ok();
|
||||
Ok((io, peer_addr))
|
||||
})
|
||||
.and_then(self.map_err(SslError::Service))
|
||||
}
|
||||
}
|
||||
}
|
||||
// impl<S, B, X, U> H1Service<TlsStream<TcpStream>, S, B, X, U>
|
||||
// where
|
||||
// S: ServiceFactory<Config = (), Request = Request>,
|
||||
// S::Error: ResponseError + 'static,
|
||||
// S::InitError: fmt::Debug,
|
||||
// S::Response: Into<Response<B>>,
|
||||
// S::Future: 'static,
|
||||
// B: MessageBody,
|
||||
// X: ServiceFactory<Config = (), Request = Request, Response = Request>,
|
||||
// X::Error: ResponseError + 'static,
|
||||
// X::InitError: fmt::Debug,
|
||||
// X::Future: 'static,
|
||||
// U: ServiceFactory<
|
||||
// Config = (),
|
||||
// Request = (Request, TlsStream<TcpStream>, IoState, Codec),
|
||||
// Response = (),
|
||||
// >,
|
||||
// U::Error: fmt::Display + Error + 'static,
|
||||
// U::InitError: fmt::Debug,
|
||||
// U::Future: 'static,
|
||||
// {
|
||||
// /// Create rustls based service
|
||||
// pub fn rustls(
|
||||
// self,
|
||||
// config: ServerConfig,
|
||||
// ) -> impl ServiceFactory<
|
||||
// Config = (),
|
||||
// Request = TcpStream,
|
||||
// Response = (),
|
||||
// Error = SslError<DispatchError>,
|
||||
// InitError = (),
|
||||
// > {
|
||||
// pipeline_factory(
|
||||
// Acceptor::new(config)
|
||||
// .timeout(self.handshake_timeout)
|
||||
// .map_err(SslError::Ssl)
|
||||
// .map_init_err(|_| panic!()),
|
||||
// )
|
||||
// .and_then(|io: TlsStream<TcpStream>| async move {
|
||||
// let peer_addr = io.get_ref().0.peer_addr().ok();
|
||||
// Ok((io, peer_addr))
|
||||
// })
|
||||
// .and_then(self.map_err(SslError::Service))
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<T, S, B, X, U> H1Service<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> H1Service<F, S, B, X, U>
|
||||
where
|
||||
F: Filter,
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
|
@ -213,7 +168,7 @@ where
|
|||
S::Future: 'static,
|
||||
B: MessageBody,
|
||||
{
|
||||
pub fn expect<X1>(self, expect: X1) -> H1Service<T, S, B, X1, U>
|
||||
pub fn expect<X1>(self, expect: X1) -> H1Service<F, S, B, X1, U>
|
||||
where
|
||||
X1: ServiceFactory<Request = Request, Response = Request>,
|
||||
X1::Error: ResponseError + 'static,
|
||||
|
@ -225,16 +180,15 @@ where
|
|||
cfg: self.cfg,
|
||||
srv: self.srv,
|
||||
upgrade: self.upgrade,
|
||||
on_connect: self.on_connect,
|
||||
on_request: self.on_request,
|
||||
handshake_timeout: self.handshake_timeout,
|
||||
_t: marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<T, S, B, X, U1>
|
||||
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> H1Service<F, S, B, X, U1>
|
||||
where
|
||||
U1: ServiceFactory<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U1: ServiceFactory<Request = (Request, Io<F>, Codec), Response = ()>,
|
||||
U1::Error: fmt::Display + Error + 'static,
|
||||
U1::InitError: fmt::Debug,
|
||||
U1::Future: 'static,
|
||||
|
@ -244,34 +198,24 @@ where
|
|||
cfg: self.cfg,
|
||||
srv: self.srv,
|
||||
expect: self.expect,
|
||||
on_connect: self.on_connect,
|
||||
on_request: self.on_request,
|
||||
handshake_timeout: self.handshake_timeout,
|
||||
_t: marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set on connect callback.
|
||||
pub(crate) fn on_connect(
|
||||
mut self,
|
||||
f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
|
||||
) -> Self {
|
||||
self.on_connect = f;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set req request callback.
|
||||
///
|
||||
/// It get called once per request.
|
||||
pub(crate) fn on_request(self, f: Option<OnRequest<T>>) -> Self {
|
||||
pub(crate) fn on_request(self, f: Option<OnRequest>) -> Self {
|
||||
*self.on_request.borrow_mut() = f;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> ServiceFactory for H1Service<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> ServiceFactory for H1Service<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Filter + 'static,
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
|
@ -282,28 +226,23 @@ where
|
|||
X::Error: ResponseError + 'static,
|
||||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, T, IoState, Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U: ServiceFactory<Config = (), Request = (Request, Io<F>, Codec), Response = ()>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
{
|
||||
type Config = ();
|
||||
type Request = (T, Option<net::SocketAddr>);
|
||||
type Request = Io<F>;
|
||||
type Response = ();
|
||||
type Error = DispatchError;
|
||||
type InitError = ();
|
||||
type Service = H1ServiceHandler<T, S::Service, B, X::Service, U::Service>;
|
||||
type Service = H1ServiceHandler<F, S::Service, B, X::Service, U::Service>;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Self::InitError>>>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
let fut = self.srv.new_service(());
|
||||
let fut_ex = self.expect.new_service(());
|
||||
let fut_upg = self.upgrade.as_ref().map(|f| f.new_service(()));
|
||||
let on_connect = self.on_connect.clone();
|
||||
let on_request = self.on_request.borrow_mut().take();
|
||||
let cfg = self.cfg.clone();
|
||||
|
||||
|
@ -331,7 +270,6 @@ where
|
|||
Ok(H1ServiceHandler {
|
||||
pool,
|
||||
config,
|
||||
on_connect,
|
||||
_t: marker::PhantomData,
|
||||
})
|
||||
})
|
||||
|
@ -339,29 +277,28 @@ where
|
|||
}
|
||||
|
||||
/// `Service` implementation for HTTP1 transport
|
||||
pub struct H1ServiceHandler<T, S: Service, B, X: Service, U: Service> {
|
||||
pub struct H1ServiceHandler<F, S: Service, B, X: Service, U: Service> {
|
||||
pool: Pool,
|
||||
config: Rc<DispatcherConfig<T, S, X, U>>,
|
||||
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
|
||||
_t: marker::PhantomData<(T, B)>,
|
||||
config: Rc<DispatcherConfig<S, X, U>>,
|
||||
_t: marker::PhantomData<(F, B)>,
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> Service for H1ServiceHandler<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> Service for H1ServiceHandler<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Filter + 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Response: Into<Response<B>>,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, T, IoState, Codec), Response = ()>,
|
||||
U: Service<Request = (Request, Io<F>, Codec), Response = ()>,
|
||||
U::Error: fmt::Display + Error + 'static,
|
||||
{
|
||||
type Request = (T, Option<net::SocketAddr>);
|
||||
type Request = Io<F>;
|
||||
type Response = ();
|
||||
type Error = DispatchError;
|
||||
type Future = Dispatcher<T, S, B, X, U>;
|
||||
type Future = Dispatcher<F, S, B, X, U>;
|
||||
|
||||
fn poll_ready(
|
||||
&self,
|
||||
|
@ -429,9 +366,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn call(&self, (io, addr): Self::Request) -> Self::Future {
|
||||
let on_connect = self.on_connect.as_ref().map(|f| f(&io));
|
||||
|
||||
Dispatcher::new(io, self.config.clone(), addr, on_connect)
|
||||
fn call(&self, io: Self::Request) -> Self::Future {
|
||||
Dispatcher::new(io, self.config.clone())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,16 +2,17 @@ use std::{io, marker::PhantomData, task::Context, task::Poll};
|
|||
|
||||
use crate::http::h1::Codec;
|
||||
use crate::http::request::Request;
|
||||
use crate::{framed::State, util::Ready, Service, ServiceFactory};
|
||||
use crate::io::Io;
|
||||
use crate::{util::Ready, Service, ServiceFactory};
|
||||
|
||||
pub struct UpgradeHandler<T>(PhantomData<T>);
|
||||
pub struct UpgradeHandler<F>(PhantomData<F>);
|
||||
|
||||
impl<T> ServiceFactory for UpgradeHandler<T> {
|
||||
impl<F> ServiceFactory for UpgradeHandler<F> {
|
||||
type Config = ();
|
||||
type Request = (Request, T, State, Codec);
|
||||
type Request = (Request, Io<F>, Codec);
|
||||
type Response = ();
|
||||
type Error = io::Error;
|
||||
type Service = UpgradeHandler<T>;
|
||||
type Service = UpgradeHandler<F>;
|
||||
type InitError = io::Error;
|
||||
type Future = Ready<Self::Service, Self::InitError>;
|
||||
|
||||
|
@ -21,8 +22,8 @@ impl<T> ServiceFactory for UpgradeHandler<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> Service for UpgradeHandler<T> {
|
||||
type Request = (Request, T, State, Codec);
|
||||
impl<F> Service for UpgradeHandler<F> {
|
||||
type Request = (Request, Io<F>, Codec);
|
||||
type Response = ();
|
||||
type Error = io::Error;
|
||||
type Future = Ready<Self::Response, Self::Error>;
|
||||
|
|
|
@ -4,11 +4,11 @@ use std::task::{Context, Poll};
|
|||
|
||||
use h2::RecvStream;
|
||||
|
||||
mod dispatcher;
|
||||
mod service;
|
||||
//mod dispatcher;
|
||||
//mod service;
|
||||
|
||||
pub use self::dispatcher::Dispatcher;
|
||||
pub use self::service::H2Service;
|
||||
//pub use self::dispatcher::Dispatcher;
|
||||
//pub use self::service::H2Service;
|
||||
use crate::{http::error::PayloadError, util::Bytes, Stream};
|
||||
|
||||
/// H2 receive stream
|
||||
|
|
|
@ -6,6 +6,7 @@ use bitflags::bitflags;
|
|||
|
||||
use crate::http::header::HeaderMap;
|
||||
use crate::http::{header, Method, StatusCode, Uri, Version};
|
||||
use crate::io::IoRef;
|
||||
use crate::util::Extensions;
|
||||
|
||||
/// Represents various types of connection
|
||||
|
@ -45,19 +46,19 @@ pub struct RequestHead {
|
|||
pub version: Version,
|
||||
pub headers: HeaderMap,
|
||||
pub extensions: RefCell<Extensions>,
|
||||
pub peer_addr: Option<net::SocketAddr>,
|
||||
pub io: Option<IoRef>,
|
||||
pub(super) flags: Flags,
|
||||
}
|
||||
|
||||
impl Default for RequestHead {
|
||||
fn default() -> RequestHead {
|
||||
RequestHead {
|
||||
io: None,
|
||||
uri: Uri::default(),
|
||||
method: Method::default(),
|
||||
version: Version::HTTP_11,
|
||||
headers: HeaderMap::with_capacity(16),
|
||||
flags: Flags::empty(),
|
||||
peer_addr: None,
|
||||
extensions: RefCell::new(Extensions::new()),
|
||||
}
|
||||
}
|
||||
|
@ -65,6 +66,7 @@ impl Default for RequestHead {
|
|||
|
||||
impl Head for RequestHead {
|
||||
fn clear(&mut self) {
|
||||
self.io = None;
|
||||
self.flags = Flags::empty();
|
||||
self.headers.clear();
|
||||
self.extensions.get_mut().clear();
|
||||
|
|
|
@ -6,6 +6,7 @@ use crate::http::header::HeaderMap;
|
|||
use crate::http::httpmessage::HttpMessage;
|
||||
use crate::http::message::{Message, RequestHead};
|
||||
use crate::http::payload::Payload;
|
||||
use crate::io::IoRef;
|
||||
use crate::util::Extensions;
|
||||
|
||||
/// Request
|
||||
|
@ -126,13 +127,21 @@ impl Request {
|
|||
self.head().method == Method::CONNECT
|
||||
}
|
||||
|
||||
/// Io reference for current connection
|
||||
#[inline]
|
||||
pub fn io(&self) -> Option<&IoRef> {
|
||||
self.head().io.as_ref()
|
||||
}
|
||||
|
||||
/// Peer socket address
|
||||
///
|
||||
/// Peer address is actual socket address, if proxy is used in front of
|
||||
/// ntex http server, then peer address would be address of this proxy.
|
||||
#[inline]
|
||||
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
|
||||
self.head().peer_addr
|
||||
// TODO! fix
|
||||
// self.head().peer_addr
|
||||
None
|
||||
}
|
||||
|
||||
/// Get request's payload
|
||||
|
|
|
@ -6,7 +6,7 @@ use std::{
|
|||
use h2::server::{self, Handshake};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::framed::State;
|
||||
use crate::io::{DefaultFilter, Filter, Io, IoRef};
|
||||
use crate::rt::net::TcpStream;
|
||||
use crate::service::{pipeline_factory, IntoServiceFactory, Service, ServiceFactory};
|
||||
use crate::time::{Millis, Seconds};
|
||||
|
@ -19,20 +19,20 @@ use super::error::{DispatchError, ResponseError};
|
|||
use super::helpers::DataFactory;
|
||||
use super::request::Request;
|
||||
use super::response::Response;
|
||||
use super::{h1, h2::Dispatcher, Protocol};
|
||||
//use super::{h1, h2::Dispatcher, Protocol};
|
||||
use super::{h1, Protocol};
|
||||
|
||||
/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation
|
||||
pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> {
|
||||
pub struct HttpService<F, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<F>> {
|
||||
srv: S,
|
||||
cfg: ServiceConfig,
|
||||
expect: X,
|
||||
upgrade: Option<U>,
|
||||
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
|
||||
on_request: cell::RefCell<Option<OnRequest<T>>>,
|
||||
_t: marker::PhantomData<(T, B)>,
|
||||
on_request: cell::RefCell<Option<OnRequest>>,
|
||||
_t: marker::PhantomData<(F, B)>,
|
||||
}
|
||||
|
||||
impl<T, S, B> HttpService<T, S, B>
|
||||
impl<F, S, B> HttpService<F, S, B>
|
||||
where
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
|
@ -43,13 +43,14 @@ where
|
|||
B: MessageBody + 'static,
|
||||
{
|
||||
/// Create builder for `HttpService` instance.
|
||||
pub fn build() -> HttpServiceBuilder<T, S> {
|
||||
pub fn build() -> HttpServiceBuilder<F, S> {
|
||||
HttpServiceBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B> HttpService<T, S, B>
|
||||
impl<F, S, B> HttpService<F, S, B>
|
||||
where
|
||||
F: Filter,
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
|
@ -59,7 +60,7 @@ where
|
|||
B: MessageBody + 'static,
|
||||
{
|
||||
/// Create new `HttpService` instance.
|
||||
pub fn new<F: IntoServiceFactory<S>>(service: F) -> Self {
|
||||
pub fn new<U: IntoServiceFactory<S>>(service: U) -> Self {
|
||||
let cfg = ServiceConfig::new(
|
||||
KeepAlive::Timeout(Seconds(5)),
|
||||
Millis(5_000),
|
||||
|
@ -73,31 +74,30 @@ where
|
|||
srv: service.into_factory(),
|
||||
expect: h1::ExpectHandler,
|
||||
upgrade: None,
|
||||
on_connect: None,
|
||||
on_request: cell::RefCell::new(None),
|
||||
_t: marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create new `HttpService` instance with config.
|
||||
pub(crate) fn with_config<F: IntoServiceFactory<S>>(
|
||||
pub(crate) fn with_config<U: IntoServiceFactory<S>>(
|
||||
cfg: ServiceConfig,
|
||||
service: F,
|
||||
service: U,
|
||||
) -> Self {
|
||||
HttpService {
|
||||
cfg,
|
||||
srv: service.into_factory(),
|
||||
expect: h1::ExpectHandler,
|
||||
upgrade: None,
|
||||
on_connect: None,
|
||||
on_request: cell::RefCell::new(None),
|
||||
_t: marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> HttpService<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> HttpService<F, S, B, X, U>
|
||||
where
|
||||
F: Filter,
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
|
@ -111,7 +111,7 @@ where
|
|||
/// Service get called with request that contains `EXPECT` header.
|
||||
/// Service must return request in case of success, in that case
|
||||
/// request will be forwarded to main service.
|
||||
pub fn expect<X1>(self, expect: X1) -> HttpService<T, S, B, X1, U>
|
||||
pub fn expect<X1>(self, expect: X1) -> HttpService<F, S, B, X1, U>
|
||||
where
|
||||
X1: ServiceFactory<Config = (), Request = Request, Response = Request>,
|
||||
X1::Error: ResponseError,
|
||||
|
@ -123,7 +123,6 @@ where
|
|||
cfg: self.cfg,
|
||||
srv: self.srv,
|
||||
upgrade: self.upgrade,
|
||||
on_connect: self.on_connect,
|
||||
on_request: self.on_request,
|
||||
_t: marker::PhantomData,
|
||||
}
|
||||
|
@ -133,11 +132,11 @@ where
|
|||
///
|
||||
/// If service is provided then normal requests handling get halted
|
||||
/// and this service get called with original request and framed object.
|
||||
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<T, S, B, X, U1>
|
||||
pub fn upgrade<U1>(self, upgrade: Option<U1>) -> HttpService<F, S, B, X, U1>
|
||||
where
|
||||
U1: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, T, State, h1::Codec),
|
||||
Request = (Request, Io<F>, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U1::Error: fmt::Display + error::Error + 'static,
|
||||
|
@ -149,77 +148,25 @@ where
|
|||
cfg: self.cfg,
|
||||
srv: self.srv,
|
||||
expect: self.expect,
|
||||
on_connect: self.on_connect,
|
||||
on_request: self.on_request,
|
||||
_t: marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set on connect callback.
|
||||
pub(crate) fn on_connect(
|
||||
mut self,
|
||||
f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
|
||||
) -> Self {
|
||||
self.on_connect = f;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set on request callback.
|
||||
pub(crate) fn on_request(self, f: Option<OnRequest<T>>) -> Self {
|
||||
pub(crate) fn on_request(self, f: Option<OnRequest>) -> Self {
|
||||
*self.on_request.borrow_mut() = f;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, B, X, U> HttpService<TcpStream, S, B, X, U>
|
||||
where
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
S::Response: Into<Response<B>> + 'static,
|
||||
S::Future: 'static,
|
||||
<S::Service as Service>::Future: 'static,
|
||||
B: MessageBody + 'static,
|
||||
X: ServiceFactory<Config = (), Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, TcpStream, State, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
<U::Service as Service>::Future: 'static,
|
||||
{
|
||||
/// Create simple tcp stream service
|
||||
pub fn tcp(
|
||||
self,
|
||||
) -> impl ServiceFactory<
|
||||
Config = (),
|
||||
Request = TcpStream,
|
||||
Response = (),
|
||||
Error = DispatchError,
|
||||
InitError = (),
|
||||
> {
|
||||
pipeline_factory(|io: TcpStream| async move {
|
||||
let peer_addr = io.peer_addr().ok();
|
||||
Ok((io, Protocol::Http1, peer_addr))
|
||||
})
|
||||
.and_then(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "openssl")]
|
||||
mod openssl {
|
||||
use super::*;
|
||||
use crate::server::openssl::{Acceptor, SslAcceptor, SslStream};
|
||||
use crate::server::openssl::{Acceptor, SslAcceptor, SslFilter};
|
||||
use crate::server::SslError;
|
||||
|
||||
impl<S, B, X, U> HttpService<SslStream<TcpStream>, S, B, X, U>
|
||||
impl<S, B, X, U> HttpService<SslFilter<DefaultFilter>, S, B, X, U>
|
||||
where
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
|
@ -235,7 +182,7 @@ mod openssl {
|
|||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, SslStream<TcpStream>, State, h1::Codec),
|
||||
Request = (Request, Io<SslFilter<DefaultFilter>>, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
|
@ -249,7 +196,7 @@ mod openssl {
|
|||
acceptor: SslAcceptor,
|
||||
) -> impl ServiceFactory<
|
||||
Config = (),
|
||||
Request = TcpStream,
|
||||
Request = Io<DefaultFilter>,
|
||||
Response = (),
|
||||
Error = SslError<DispatchError>,
|
||||
InitError = (),
|
||||
|
@ -260,19 +207,6 @@ mod openssl {
|
|||
.map_err(SslError::Ssl)
|
||||
.map_init_err(|_| panic!()),
|
||||
)
|
||||
.and_then(|io: SslStream<TcpStream>| async move {
|
||||
let proto = if let Some(protos) = io.ssl().selected_alpn_protocol() {
|
||||
if protos.windows(2).any(|window| window == b"h2") {
|
||||
Protocol::Http2
|
||||
} else {
|
||||
Protocol::Http1
|
||||
}
|
||||
} else {
|
||||
Protocol::Http1
|
||||
};
|
||||
let peer_addr = io.get_ref().peer_addr().ok();
|
||||
Ok((io, proto, peer_addr))
|
||||
})
|
||||
.and_then(self.map_err(SslError::Service))
|
||||
}
|
||||
}
|
||||
|
@ -284,8 +218,9 @@ mod rustls {
|
|||
use crate::server::rustls::{Acceptor, ServerConfig, TlsStream};
|
||||
use crate::server::SslError;
|
||||
|
||||
impl<S, B, X, U> HttpService<TlsStream<TcpStream>, S, B, X, U>
|
||||
impl<F, S, B, X, U> HttpService<F, S, B, X, U>
|
||||
where
|
||||
F: Filter,
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
|
@ -300,7 +235,7 @@ mod rustls {
|
|||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, TlsStream<TcpStream>, State, h1::Codec),
|
||||
Request = (Request, Io<F>, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
|
@ -311,47 +246,49 @@ mod rustls {
|
|||
/// Create openssl based service
|
||||
pub fn rustls(
|
||||
self,
|
||||
mut config: ServerConfig,
|
||||
config: ServerConfig,
|
||||
) -> impl ServiceFactory<
|
||||
Config = (),
|
||||
Request = TcpStream,
|
||||
Request = Io<F>,
|
||||
Response = (),
|
||||
Error = SslError<DispatchError>,
|
||||
//Error = SslError<DispatchError>,
|
||||
Error = DispatchError,
|
||||
InitError = (),
|
||||
> {
|
||||
let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()];
|
||||
config.alpn_protocols = protos;
|
||||
self
|
||||
// let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()];
|
||||
// config.alpn_protocols = protos;
|
||||
|
||||
pipeline_factory(
|
||||
Acceptor::new(config)
|
||||
.timeout(self.cfg.0.ssl_handshake_timeout)
|
||||
.map_err(SslError::Ssl)
|
||||
.map_init_err(|_| panic!()),
|
||||
)
|
||||
.and_then(|io: TlsStream<TcpStream>| async move {
|
||||
let proto = io
|
||||
.get_ref()
|
||||
.1
|
||||
.alpn_protocol()
|
||||
.and_then(|protos| {
|
||||
if protos.windows(2).any(|window| window == b"h2") {
|
||||
Some(Protocol::Http2)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(Protocol::Http1);
|
||||
let peer_addr = io.get_ref().0.peer_addr().ok();
|
||||
Ok((io, proto, peer_addr))
|
||||
})
|
||||
.and_then(self.map_err(SslError::Service))
|
||||
// pipeline_factory(
|
||||
// Acceptor::new(config)
|
||||
// .timeout(self.cfg.0.ssl_handshake_timeout)
|
||||
// .map_err(SslError::Ssl)
|
||||
// .map_init_err(|_| panic!()),
|
||||
// )
|
||||
// .and_then(|io: TlsStream<TcpStream>| async move {
|
||||
// let proto = io
|
||||
// .get_ref()
|
||||
// .1
|
||||
// .alpn_protocol()
|
||||
// .and_then(|protos| {
|
||||
// if protos.windows(2).any(|window| window == b"h2") {
|
||||
// Some(Protocol::Http2)
|
||||
// } else {
|
||||
// None
|
||||
// }
|
||||
// })
|
||||
// .unwrap_or(Protocol::Http1);
|
||||
// let peer_addr = io.get_ref().0.peer_addr().ok();
|
||||
// Ok((io, proto, peer_addr))
|
||||
// })
|
||||
// .and_then(self.map_err(SslError::Service))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> ServiceFactory for HttpService<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> ServiceFactory for HttpService<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Filter + 'static,
|
||||
S: ServiceFactory<Config = (), Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::InitError: fmt::Debug,
|
||||
|
@ -364,29 +301,24 @@ where
|
|||
X::InitError: fmt::Debug,
|
||||
X::Future: 'static,
|
||||
<X::Service as Service>::Future: 'static,
|
||||
U: ServiceFactory<
|
||||
Config = (),
|
||||
Request = (Request, T, State, h1::Codec),
|
||||
Response = (),
|
||||
>,
|
||||
U: ServiceFactory<Config = (), Request = (Request, Io<F>, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
U::InitError: fmt::Debug,
|
||||
U::Future: 'static,
|
||||
<U::Service as Service>::Future: 'static,
|
||||
{
|
||||
type Config = ();
|
||||
type Request = (T, Protocol, Option<net::SocketAddr>);
|
||||
type Request = Io<F>;
|
||||
type Response = ();
|
||||
type Error = DispatchError;
|
||||
type InitError = ();
|
||||
type Service = HttpServiceHandler<T, S::Service, B, X::Service, U::Service>;
|
||||
type Service = HttpServiceHandler<F, S::Service, B, X::Service, U::Service>;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Self::InitError>>>>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
let fut = self.srv.new_service(());
|
||||
let fut_ex = self.expect.new_service(());
|
||||
let fut_upg = self.upgrade.as_ref().map(|f| f.new_service(()));
|
||||
let on_connect = self.on_connect.clone();
|
||||
let on_request = self.on_request.borrow_mut().take();
|
||||
let cfg = self.cfg.clone();
|
||||
|
||||
|
@ -414,7 +346,6 @@ where
|
|||
|
||||
Ok(HttpServiceHandler {
|
||||
pool,
|
||||
on_connect,
|
||||
config: Rc::new(config),
|
||||
_t: marker::PhantomData,
|
||||
})
|
||||
|
@ -423,16 +354,15 @@ where
|
|||
}
|
||||
|
||||
/// `Service` implementation for http transport
|
||||
pub struct HttpServiceHandler<T, S: Service, B, X: Service, U: Service> {
|
||||
pub struct HttpServiceHandler<F, S: Service, B, X: Service, U: Service> {
|
||||
pool: Pool,
|
||||
config: Rc<DispatcherConfig<T, S, X, U>>,
|
||||
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
|
||||
_t: marker::PhantomData<(T, B, X)>,
|
||||
config: Rc<DispatcherConfig<S, X, U>>,
|
||||
_t: marker::PhantomData<(F, B, X)>,
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> Service for HttpServiceHandler<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> Service for HttpServiceHandler<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Filter + 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Future: 'static,
|
||||
|
@ -440,13 +370,13 @@ where
|
|||
B: MessageBody + 'static,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
{
|
||||
type Request = (T, Protocol, Option<net::SocketAddr>);
|
||||
type Request = Io<F>;
|
||||
type Response = ();
|
||||
type Error = DispatchError;
|
||||
type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
|
||||
type Future = HttpServiceHandlerResponse<F, S, B, X, U>;
|
||||
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let cfg = self.config.as_ref();
|
||||
|
@ -507,46 +437,36 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn call(&self, (io, proto, peer_addr): Self::Request) -> Self::Future {
|
||||
log::trace!(
|
||||
"New http connection protocol {:?} peer address {:?}",
|
||||
proto,
|
||||
peer_addr
|
||||
);
|
||||
let on_connect = self.on_connect.as_ref().map(|f| f(&io));
|
||||
fn call(&self, io: Self::Request) -> Self::Future {
|
||||
// log::trace!("New http connection protocol {:?}", proto);
|
||||
|
||||
match proto {
|
||||
Protocol::Http2 => HttpServiceHandlerResponse {
|
||||
state: ResponseState::H2Handshake {
|
||||
data: Some((
|
||||
server::Builder::new().handshake(io),
|
||||
self.config.clone(),
|
||||
on_connect,
|
||||
peer_addr,
|
||||
)),
|
||||
},
|
||||
},
|
||||
Protocol::Http1 => HttpServiceHandlerResponse {
|
||||
state: ResponseState::H1 {
|
||||
fut: h1::Dispatcher::new(
|
||||
io,
|
||||
self.config.clone(),
|
||||
peer_addr,
|
||||
on_connect,
|
||||
),
|
||||
},
|
||||
//match proto {
|
||||
//Protocol::Http2 => todo!(),
|
||||
// HttpServiceHandlerResponse {
|
||||
// state: ResponseState::H2Handshake {
|
||||
// data: Some((
|
||||
// server::Builder::new().handshake(io),
|
||||
// self.config.clone(),
|
||||
// on_connect,
|
||||
// peer_addr,
|
||||
// )),
|
||||
// },
|
||||
// },
|
||||
// Protocol::Http1 =>
|
||||
HttpServiceHandlerResponse {
|
||||
state: ResponseState::H1 {
|
||||
fut: h1::Dispatcher::new(io, self.config.clone()),
|
||||
},
|
||||
// },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct HttpServiceHandlerResponse<T, S, B, X, U>
|
||||
pub struct HttpServiceHandlerResponse<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
T: AsyncWrite,
|
||||
T: Unpin,
|
||||
T: 'static,
|
||||
F: Filter,
|
||||
F: 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError,
|
||||
S::Error: 'static,
|
||||
|
@ -557,52 +477,50 @@ pin_project_lite::pin_project! {
|
|||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError,
|
||||
X::Error: 'static,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display,
|
||||
U::Error: error::Error,
|
||||
U::Error: 'static,
|
||||
{
|
||||
#[pin]
|
||||
state: ResponseState<T, S, B, X, U>,
|
||||
state: ResponseState<F, S, B, X, U>,
|
||||
}
|
||||
}
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[project = StateProject]
|
||||
enum ResponseState<T, S, B, X, U>
|
||||
enum ResponseState<F, S, B, X, U>
|
||||
where
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError,
|
||||
S::Error: 'static,
|
||||
T: AsyncRead,
|
||||
T: AsyncWrite,
|
||||
T: Unpin,
|
||||
T: 'static,
|
||||
F: Filter,
|
||||
F: 'static,
|
||||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError,
|
||||
X::Error: 'static,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display,
|
||||
U::Error: error::Error,
|
||||
U::Error: 'static,
|
||||
{
|
||||
H1 { #[pin] fut: h1::Dispatcher<T, S, B, X, U> },
|
||||
H2 { fut: Dispatcher<T, S, B, X, U> },
|
||||
H2Handshake { data:
|
||||
Option<(
|
||||
Handshake<T, Bytes>,
|
||||
Rc<DispatcherConfig<T, S, X, U>>,
|
||||
Option<Box<dyn DataFactory>>,
|
||||
Option<net::SocketAddr>,
|
||||
)>,
|
||||
},
|
||||
H1 { #[pin] fut: h1::Dispatcher<F, S, B, X, U> },
|
||||
// H2 { fut: Dispatcher<F, S, B, X, U> },
|
||||
// H2Handshake { data:
|
||||
// Option<(
|
||||
// Handshake<T, Bytes>,
|
||||
// Rc<DispatcherConfig<S, X, U>>,
|
||||
// Option<Box<dyn DataFactory>>,
|
||||
// Option<net::SocketAddr>,
|
||||
// )>,
|
||||
// },
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S, B, X, U> Future for HttpServiceHandlerResponse<T, S, B, X, U>
|
||||
impl<F, S, B, X, U> Future for HttpServiceHandlerResponse<F, S, B, X, U>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
F: Filter + 'static,
|
||||
S: Service<Request = Request>,
|
||||
S::Error: ResponseError + 'static,
|
||||
S::Future: 'static,
|
||||
|
@ -610,7 +528,7 @@ where
|
|||
B: MessageBody,
|
||||
X: Service<Request = Request, Response = Request>,
|
||||
X::Error: ResponseError + 'static,
|
||||
U: Service<Request = (Request, T, State, h1::Codec), Response = ()>,
|
||||
U: Service<Request = (Request, Io<F>, h1::Codec), Response = ()>,
|
||||
U::Error: fmt::Display + error::Error + 'static,
|
||||
{
|
||||
type Output = Result<(), DispatchError>;
|
||||
|
@ -620,26 +538,26 @@ where
|
|||
|
||||
match this.state.project() {
|
||||
StateProject::H1 { fut } => fut.poll(cx),
|
||||
StateProject::H2 { ref mut fut } => Pin::new(fut).poll(cx),
|
||||
StateProject::H2Handshake { data } => {
|
||||
let conn = if let Some(ref mut item) = data {
|
||||
match Pin::new(&mut item.0).poll(cx) {
|
||||
Poll::Ready(Ok(conn)) => conn,
|
||||
Poll::Ready(Err(err)) => {
|
||||
trace!("H2 handshake error: {}", err);
|
||||
return Poll::Ready(Err(err.into()));
|
||||
}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
} else {
|
||||
panic!()
|
||||
};
|
||||
let (_, cfg, on_connect, peer_addr) = data.take().unwrap();
|
||||
self.as_mut().project().state.set(ResponseState::H2 {
|
||||
fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr),
|
||||
});
|
||||
self.poll(cx)
|
||||
}
|
||||
// StateProject::H2 { ref mut fut } => Pin::new(fut).poll(cx),
|
||||
// StateProject::H2Handshake { data } => {
|
||||
// let conn = if let Some(ref mut item) = data {
|
||||
// match Pin::new(&mut item.0).poll(cx) {
|
||||
// Poll::Ready(Ok(conn)) => conn,
|
||||
// Poll::Ready(Err(err)) => {
|
||||
// trace!("H2 handshake error: {}", err);
|
||||
// return Poll::Ready(Err(err.into()));
|
||||
// }
|
||||
// Poll::Pending => return Poll::Pending,
|
||||
// }
|
||||
// } else {
|
||||
// panic!()
|
||||
// };
|
||||
// let (_, cfg, on_connect, peer_addr) = data.take().unwrap();
|
||||
// self.as_mut().project().state.set(ResponseState::H2 {
|
||||
// fut: Dispatcher::new(cfg, conn, on_connect, None, peer_addr),
|
||||
// });
|
||||
// self.poll(cx)
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ use std::{convert::TryFrom, io, net, str::FromStr, sync::mpsc, thread};
|
|||
use coo_kie::{Cookie, CookieJar};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite, Framed};
|
||||
use crate::io::IoBoxed;
|
||||
use crate::rt::{net::TcpStream, System};
|
||||
use crate::server::{Server, StreamServiceFactory};
|
||||
use crate::{time::Millis, time::Seconds, util::Bytes};
|
||||
|
||||
use super::client::error::WsClientError;
|
||||
use super::client::{Client, ClientRequest, ClientResponse, Connector};
|
||||
use super::client::{
|
||||
ws::WsConnection, Client, ClientRequest, ClientResponse, Connector,
|
||||
};
|
||||
use super::error::{HttpError, PayloadError};
|
||||
use super::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use super::payload::Payload;
|
||||
|
@ -207,7 +210,7 @@ fn parts(parts: &mut Option<Inner>) -> &mut Inner {
|
|||
/// assert!(response.status().is_success());
|
||||
/// }
|
||||
/// ```
|
||||
pub fn server<F: StreamServiceFactory<TcpStream>>(factory: F) -> TestServer {
|
||||
pub fn server<F: StreamServiceFactory>(factory: F) -> TestServer {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
// run server in separate thread
|
||||
|
@ -318,21 +321,12 @@ impl TestServer {
|
|||
}
|
||||
|
||||
/// Connect to websocket server at a given path
|
||||
pub async fn ws_at(
|
||||
&mut self,
|
||||
path: &str,
|
||||
) -> Result<Framed<impl AsyncRead + AsyncWrite, crate::ws::Codec>, WsClientError>
|
||||
{
|
||||
let url = self.url(path);
|
||||
let connect = self.client.ws(url).connect();
|
||||
connect.await.map(|ws| ws.into_inner().1)
|
||||
pub async fn ws_at(&mut self, path: &str) -> Result<WsConnection, WsClientError> {
|
||||
self.client.ws(self.url(path)).connect().await
|
||||
}
|
||||
|
||||
/// Connect to a websocket server
|
||||
pub async fn ws(
|
||||
&mut self,
|
||||
) -> Result<Framed<impl AsyncRead + AsyncWrite, crate::ws::Codec>, WsClientError>
|
||||
{
|
||||
pub async fn ws(&mut self) -> Result<WsConnection, WsClientError> {
|
||||
self.ws_at("/").await
|
||||
}
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@
|
|||
//! * `compress` - enables compression support in http and web modules
|
||||
//! * `cookie` - enables cookie support in http and web modules
|
||||
|
||||
#![warn(
|
||||
rust_2018_idioms,
|
||||
unreachable_pub,
|
||||
// missing_debug_implementations,
|
||||
// missing_docs,
|
||||
)]
|
||||
//#![warn(
|
||||
// rust_2018_idioms,
|
||||
// unreachable_pub,
|
||||
// missing_debug_implementations,
|
||||
// missing_docs,
|
||||
//)]
|
||||
#![allow(
|
||||
type_alias_bounds,
|
||||
clippy::type_complexity,
|
||||
|
@ -21,6 +21,7 @@
|
|||
clippy::too_many_arguments,
|
||||
clippy::new_without_default
|
||||
)]
|
||||
#![allow(unused_imports)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
@ -35,7 +36,7 @@ pub(crate) use ntex_macros::rt_test2 as rt_test;
|
|||
|
||||
pub mod channel;
|
||||
pub mod connect;
|
||||
pub mod framed;
|
||||
//pub mod framed;
|
||||
#[cfg(feature = "http-framework")]
|
||||
pub mod http;
|
||||
pub mod server;
|
||||
|
|
|
@ -193,7 +193,7 @@ impl ServerBuilder {
|
|||
factory: F,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
F: StreamServiceFactory<TcpStream>,
|
||||
F: StreamServiceFactory,
|
||||
U: net::ToSocketAddrs,
|
||||
{
|
||||
let sockets = bind_addr(addr, self.backlog)?;
|
||||
|
@ -219,7 +219,7 @@ impl ServerBuilder {
|
|||
/// Add new unix domain service to the server.
|
||||
pub fn bind_uds<F, U, N>(self, name: N, addr: U, factory: F) -> io::Result<Self>
|
||||
where
|
||||
F: StreamServiceFactory<crate::rt::net::UnixStream>,
|
||||
F: StreamServiceFactory,
|
||||
N: AsRef<str>,
|
||||
U: AsRef<std::path::Path>,
|
||||
{
|
||||
|
@ -249,7 +249,7 @@ impl ServerBuilder {
|
|||
factory: F,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
F: StreamServiceFactory<crate::rt::net::UnixStream>,
|
||||
F: StreamServiceFactory,
|
||||
{
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
let token = self.token.next();
|
||||
|
@ -273,7 +273,7 @@ impl ServerBuilder {
|
|||
factory: F,
|
||||
) -> io::Result<Self>
|
||||
where
|
||||
F: StreamServiceFactory<TcpStream>,
|
||||
F: StreamServiceFactory,
|
||||
{
|
||||
let token = self.token.next();
|
||||
self.services.push(Factory::create(
|
||||
|
|
|
@ -6,8 +6,8 @@ use std::{
|
|||
use log::error;
|
||||
|
||||
use crate::rt::net::TcpStream;
|
||||
use crate::service;
|
||||
use crate::util::{counter::CounterGuard, HashMap, Ready};
|
||||
use crate::{io::Io, service};
|
||||
|
||||
use super::builder::bind_addr;
|
||||
use super::service::{
|
||||
|
@ -199,7 +199,7 @@ impl InternalServiceFactory for ConfiguredService {
|
|||
res.push((
|
||||
token,
|
||||
Box::new(StreamService::new(service::fn_service(
|
||||
move |_: TcpStream| {
|
||||
move |_: Io| {
|
||||
error!("Service {:?} is not configured", name);
|
||||
Ready::<_, ()>::Ok(())
|
||||
},
|
||||
|
@ -292,7 +292,7 @@ impl ServiceRuntime {
|
|||
pub fn service<T, F>(&self, name: &str, service: F)
|
||||
where
|
||||
F: service::IntoServiceFactory<T>,
|
||||
T: service::ServiceFactory<Config = (), Request = TcpStream> + 'static,
|
||||
T: service::ServiceFactory<Config = (), Request = Io> + 'static,
|
||||
T::Future: 'static,
|
||||
T::Service: 'static,
|
||||
T::InitError: fmt::Debug,
|
||||
|
@ -338,7 +338,7 @@ struct ServiceFactory<T> {
|
|||
|
||||
impl<T> service::ServiceFactory for ServiceFactory<T>
|
||||
where
|
||||
T: service::ServiceFactory<Config = (), Request = TcpStream>,
|
||||
T: service::ServiceFactory<Config = (), Request = Io>,
|
||||
T::Future: 'static,
|
||||
T::Service: 'static,
|
||||
T::Error: 'static,
|
||||
|
|
|
@ -30,9 +30,6 @@ pub use self::config::{ServiceConfig, ServiceRuntime};
|
|||
pub use self::service::StreamServiceFactory;
|
||||
pub use self::test::{build_test_server, test_server, TestServer};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub use self::socket::FromStream;
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
/// Server readiness status
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
use std::task::{Context, Poll};
|
||||
use std::{error::Error, fmt, future::Future, io, marker::PhantomData, pin::Pin};
|
||||
|
||||
pub use ntex_openssl::SslFilter;
|
||||
pub use open_ssl::ssl::{self, AlpnError, Ssl, SslAcceptor, SslAcceptorBuilder};
|
||||
pub use tokio_openssl::SslStream;
|
||||
|
||||
use ntex_openssl::SslAcceptor as IoSslAcceptor;
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::io::{Filter, FilterFactory, Io};
|
||||
use crate::service::{Service, ServiceFactory};
|
||||
use crate::time::{sleep, Millis, Sleep};
|
||||
use crate::util::{counter::Counter, counter::CounterGuard, Ready};
|
||||
|
@ -14,19 +17,17 @@ use super::MAX_SSL_ACCEPT_COUNTER;
|
|||
/// Support `TLS` server connections via openssl package
|
||||
///
|
||||
/// `openssl` feature enables `Acceptor` type
|
||||
pub struct Acceptor<T: AsyncRead + AsyncWrite> {
|
||||
acceptor: SslAcceptor,
|
||||
timeout: Millis,
|
||||
io: PhantomData<T>,
|
||||
pub struct Acceptor<F> {
|
||||
acceptor: IoSslAcceptor,
|
||||
_t: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite> Acceptor<T> {
|
||||
impl<F> Acceptor<F> {
|
||||
/// Create default openssl acceptor service
|
||||
pub fn new(acceptor: SslAcceptor) -> Self {
|
||||
Acceptor {
|
||||
acceptor,
|
||||
timeout: Millis(5_000),
|
||||
io: PhantomData,
|
||||
acceptor: IoSslAcceptor::new(acceptor),
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,30 +35,26 @@ impl<T: AsyncRead + AsyncWrite> Acceptor<T> {
|
|||
///
|
||||
/// Default is set to 5 seconds.
|
||||
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
|
||||
self.timeout = timeout.into();
|
||||
self.acceptor.timeout(timeout);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite> Clone for Acceptor<T> {
|
||||
impl<F> Clone for Acceptor<F> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
acceptor: self.acceptor.clone(),
|
||||
timeout: self.timeout,
|
||||
io: PhantomData,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ServiceFactory for Acceptor<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
|
||||
{
|
||||
type Request = T;
|
||||
type Response = SslStream<T>;
|
||||
impl<F: Filter + 'static> ServiceFactory for Acceptor<F> {
|
||||
type Request = Io<F>;
|
||||
type Response = Io<SslFilter<F>>;
|
||||
type Error = Box<dyn Error>;
|
||||
type Config = ();
|
||||
type Service = AcceptorService<T>;
|
||||
type Service = AcceptorService<F>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Self::Service, Self::InitError>;
|
||||
|
||||
|
@ -66,28 +63,23 @@ where
|
|||
Ready::Ok(AcceptorService {
|
||||
acceptor: self.acceptor.clone(),
|
||||
conns: conns.priv_clone(),
|
||||
timeout: self.timeout,
|
||||
io: PhantomData,
|
||||
_t: PhantomData,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AcceptorService<T> {
|
||||
acceptor: SslAcceptor,
|
||||
pub struct AcceptorService<F> {
|
||||
acceptor: IoSslAcceptor,
|
||||
conns: Counter,
|
||||
timeout: Millis,
|
||||
io: PhantomData<T>,
|
||||
_t: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<T> Service for AcceptorService<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin + fmt::Debug + 'static,
|
||||
{
|
||||
type Request = T;
|
||||
type Response = SslStream<T>;
|
||||
impl<F: Filter + 'static> Service for AcceptorService<F> {
|
||||
type Request = Io<F>;
|
||||
type Response = Io<SslFilter<F>>;
|
||||
type Error = Box<dyn Error>;
|
||||
type Future = AcceptorServiceResponse<T>;
|
||||
type Future = AcceptorServiceResponse<F>;
|
||||
|
||||
#[inline]
|
||||
fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
@ -100,57 +92,29 @@ where
|
|||
|
||||
#[inline]
|
||||
fn call(&self, req: Self::Request) -> Self::Future {
|
||||
let ssl = Ssl::new(self.acceptor.context())
|
||||
.expect("Provided SSL acceptor was invalid.");
|
||||
AcceptorServiceResponse {
|
||||
_guard: self.conns.get(),
|
||||
io: None,
|
||||
delay: self.timeout.map(sleep),
|
||||
io_factory: Some(SslStream::new(ssl, req)),
|
||||
fut: self.acceptor.clone().create(req),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AcceptorServiceResponse<T>
|
||||
where
|
||||
T: AsyncRead,
|
||||
T: AsyncWrite,
|
||||
{
|
||||
io: Option<SslStream<T>>,
|
||||
delay: Option<Sleep>,
|
||||
io_factory: Option<Result<SslStream<T>, open_ssl::error::ErrorStack>>,
|
||||
_guard: CounterGuard,
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + Unpin> Future for AcceptorServiceResponse<T> {
|
||||
type Output = Result<SslStream<T>, Box<dyn Error>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut this = self.as_mut();
|
||||
|
||||
if let Some(ref delay) = this.delay {
|
||||
match delay.poll_elapsed(cx) {
|
||||
Poll::Pending => (),
|
||||
Poll::Ready(_) => {
|
||||
return Poll::Ready(Err(Box::new(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"ssl handshake timeout",
|
||||
))))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match this.io_factory.take() {
|
||||
Some(Ok(io)) => this.io = Some(io),
|
||||
Some(Err(err)) => return Poll::Ready(Err(Box::new(err))),
|
||||
None => (),
|
||||
}
|
||||
|
||||
let io = this.io.as_mut().unwrap();
|
||||
match Pin::new(io).poll_accept(cx) {
|
||||
Poll::Ready(Ok(_)) => Poll::Ready(Ok(this.io.take().unwrap())),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(Box::new(e))),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
pin_project_lite::pin_project! {
|
||||
pub struct AcceptorServiceResponse<F>
|
||||
where
|
||||
F: Filter,
|
||||
F: 'static,
|
||||
{
|
||||
#[pin]
|
||||
fut: <IoSslAcceptor as FilterFactory<F>>::Future,
|
||||
_guard: CounterGuard,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter + 'static> Future for AcceptorServiceResponse<F> {
|
||||
type Output = Result<Io<SslFilter<F>>, Box<dyn Error>>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.project().fut.poll(cx)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::convert::TryInto;
|
||||
use std::{
|
||||
future::Future, marker::PhantomData, net::SocketAddr, pin::Pin, task::Context,
|
||||
task::Poll,
|
||||
|
@ -5,12 +6,12 @@ use std::{
|
|||
|
||||
use log::error;
|
||||
|
||||
use crate::io::Io;
|
||||
use crate::service::{Service, ServiceFactory};
|
||||
use crate::util::{counter::CounterGuard, Ready};
|
||||
use crate::{rt::spawn, time::Millis};
|
||||
|
||||
use super::socket::{FromStream, Stream};
|
||||
use super::Token;
|
||||
use super::{socket::Stream, Token};
|
||||
|
||||
/// Server message
|
||||
pub(super) enum ServerMessage {
|
||||
|
@ -22,8 +23,8 @@ pub(super) enum ServerMessage {
|
|||
ForceShutdown,
|
||||
}
|
||||
|
||||
pub trait StreamServiceFactory<Stream: FromStream>: Send + Clone + 'static {
|
||||
type Factory: ServiceFactory<Config = (), Request = Stream>;
|
||||
pub trait StreamServiceFactory: Send + Clone + 'static {
|
||||
type Factory: ServiceFactory<Config = (), Request = Io>;
|
||||
|
||||
fn create(&self) -> Self::Factory;
|
||||
}
|
||||
|
@ -57,12 +58,11 @@ impl<T> StreamService<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T, I> Service for StreamService<T>
|
||||
impl<T> Service for StreamService<T>
|
||||
where
|
||||
T: Service<Request = I>,
|
||||
T: Service<Request = Io>,
|
||||
T::Future: 'static,
|
||||
T::Error: 'static,
|
||||
I: FromStream,
|
||||
{
|
||||
type Request = (Option<CounterGuard>, ServerMessage);
|
||||
type Response = ();
|
||||
|
@ -82,7 +82,7 @@ where
|
|||
fn call(&self, (guard, req): (Option<CounterGuard>, ServerMessage)) -> Self::Future {
|
||||
match req {
|
||||
ServerMessage::Connect(stream) => {
|
||||
let stream = FromStream::from_stream(stream).map_err(|e| {
|
||||
let stream = stream.try_into().map_err(|e| {
|
||||
error!("Cannot convert to an async io stream: {}", e);
|
||||
});
|
||||
|
||||
|
@ -102,18 +102,16 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub(super) struct Factory<F: StreamServiceFactory<Io>, Io: FromStream> {
|
||||
pub(super) struct Factory<F: StreamServiceFactory> {
|
||||
name: String,
|
||||
inner: F,
|
||||
token: Token,
|
||||
addr: SocketAddr,
|
||||
_t: PhantomData<Io>,
|
||||
}
|
||||
|
||||
impl<F, Io> Factory<F, Io>
|
||||
impl<F> Factory<F>
|
||||
where
|
||||
F: StreamServiceFactory<Io>,
|
||||
Io: FromStream + Send + 'static,
|
||||
F: StreamServiceFactory,
|
||||
{
|
||||
pub(crate) fn create(
|
||||
name: String,
|
||||
|
@ -126,15 +124,13 @@ where
|
|||
token,
|
||||
inner,
|
||||
addr,
|
||||
_t: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, Io> InternalServiceFactory for Factory<F, Io>
|
||||
impl<F> InternalServiceFactory for Factory<F>
|
||||
where
|
||||
F: StreamServiceFactory<Io>,
|
||||
Io: FromStream + Send + 'static,
|
||||
F: StreamServiceFactory,
|
||||
{
|
||||
fn name(&self, _: Token) -> &str {
|
||||
&self.name
|
||||
|
@ -146,7 +142,6 @@ where
|
|||
inner: self.inner.clone(),
|
||||
token: self.token,
|
||||
addr: self.addr,
|
||||
_t: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -187,11 +182,10 @@ impl InternalServiceFactory for Box<dyn InternalServiceFactory> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<F, T, I> StreamServiceFactory<I> for F
|
||||
impl<F, T> StreamServiceFactory for F
|
||||
where
|
||||
F: Fn() -> T + Send + Clone + 'static,
|
||||
T: ServiceFactory<Config = (), Request = I>,
|
||||
I: FromStream,
|
||||
T: ServiceFactory<Config = (), Request = Io>,
|
||||
{
|
||||
type Factory = T;
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use std::{fmt, io, net};
|
||||
use std::{convert::TryFrom, fmt, io, net};
|
||||
|
||||
use crate::codec::{AsyncRead, AsyncWrite};
|
||||
use crate::io::{Io, IoStream};
|
||||
use crate::rt::net::TcpStream;
|
||||
|
||||
pub(crate) enum Listener {
|
||||
|
@ -146,32 +147,29 @@ pub enum Stream {
|
|||
Uds(mio::net::UnixStream),
|
||||
}
|
||||
|
||||
pub trait FromStream: AsyncRead + AsyncWrite + Sized {
|
||||
fn from_stream(stream: Stream) -> io::Result<Self>;
|
||||
}
|
||||
impl TryFrom<Stream> for Io {
|
||||
type Error = io::Error;
|
||||
|
||||
#[cfg(unix)]
|
||||
impl FromStream for TcpStream {
|
||||
fn from_stream(sock: Stream) -> io::Result<Self> {
|
||||
fn try_from(sock: Stream) -> Result<Self, Self::Error> {
|
||||
#[cfg(unix)]
|
||||
match sock {
|
||||
Stream::Tcp(stream) => {
|
||||
use std::os::unix::io::{FromRawFd, IntoRawFd};
|
||||
let fd = IntoRawFd::into_raw_fd(stream);
|
||||
let io = TcpStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) })?;
|
||||
io.set_nodelay(true)?;
|
||||
Ok(io)
|
||||
Ok(Io::new(io))
|
||||
}
|
||||
#[cfg(unix)]
|
||||
Stream::Uds(_) => {
|
||||
panic!("Should not happen, bug in server impl");
|
||||
Stream::Uds(stream) => {
|
||||
use crate::rt::net::UnixStream;
|
||||
use std::os::unix::io::{FromRawFd, IntoRawFd};
|
||||
let fd = IntoRawFd::into_raw_fd(stream);
|
||||
let ud = UnixStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) });
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
impl FromStream for TcpStream {
|
||||
fn from_stream(sock: Stream) -> io::Result<Self> {
|
||||
#[cfg(windows)]
|
||||
match sock {
|
||||
Stream::Tcp(stream) => {
|
||||
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
|
||||
|
@ -179,26 +177,7 @@ impl FromStream for TcpStream {
|
|||
let io =
|
||||
TcpStream::from_std(unsafe { FromRawSocket::from_raw_socket(fd) })?;
|
||||
io.set_nodelay(true)?;
|
||||
Ok(io)
|
||||
}
|
||||
#[cfg(unix)]
|
||||
Stream::Uds(_) => {
|
||||
panic!("Should not happen, bug in server impl");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl FromStream for crate::rt::net::UnixStream {
|
||||
fn from_stream(sock: Stream) -> io::Result<Self> {
|
||||
match sock {
|
||||
Stream::Tcp(_) => panic!("Should not happen, bug in server impl"),
|
||||
Stream::Uds(stream) => {
|
||||
use crate::rt::net::UnixStream;
|
||||
use std::os::unix::io::{FromRawFd, IntoRawFd};
|
||||
let fd = IntoRawFd::into_raw_fd(stream);
|
||||
UnixStream::from_std(unsafe { FromRawFd::from_raw_fd(fd) })
|
||||
Ok(Io::new(io))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ use crate::server::{Server, ServerBuilder, StreamServiceFactory};
|
|||
/// assert!(response.status().is_success());
|
||||
/// }
|
||||
/// ```
|
||||
pub fn test_server<F: StreamServiceFactory<TcpStream>>(factory: F) -> TestServer {
|
||||
pub fn test_server<F: StreamServiceFactory>(factory: F) -> TestServer {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
// run server in separate thread
|
||||
|
|
|
@ -3,6 +3,7 @@ use std::{cell::Ref, cell::RefCell, cell::RefMut, fmt, net, rc::Rc};
|
|||
use crate::http::{
|
||||
HeaderMap, HttpMessage, Message, Method, Payload, RequestHead, Uri, Version,
|
||||
};
|
||||
use crate::io::IoRef;
|
||||
use crate::router::Path;
|
||||
use crate::util::{Extensions, Ready};
|
||||
|
||||
|
@ -105,6 +106,12 @@ impl HttpRequest {
|
|||
}
|
||||
}
|
||||
|
||||
/// Io reference for current connection
|
||||
#[inline]
|
||||
pub fn io(&self) -> Option<&IoRef> {
|
||||
self.head().io.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the Path parameters.
|
||||
///
|
||||
/// Params is a container for url parameters.
|
||||
|
@ -183,17 +190,6 @@ impl HttpRequest {
|
|||
&self.0.rmap
|
||||
}
|
||||
|
||||
/// Peer socket address
|
||||
///
|
||||
/// Peer address is actual socket address, if proxy is used in front of
|
||||
/// ntex http server, then peer address would be address of this proxy.
|
||||
///
|
||||
/// To get client connection information `.connection_info()` should be used.
|
||||
#[inline]
|
||||
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
|
||||
self.head().peer_addr
|
||||
}
|
||||
|
||||
/// Get *ConnectionInfo* for the current request.
|
||||
///
|
||||
/// This method panics if request's extensions container is already
|
||||
|
|
|
@ -119,7 +119,9 @@ impl ConnectionInfo {
|
|||
}
|
||||
if remote.is_none() {
|
||||
// get peeraddr from socketaddr
|
||||
peer = req.peer_addr.map(|addr| format!("{}", addr));
|
||||
|
||||
// TODO! fix
|
||||
// peer = req.peer_addr.map(|addr| format!("{}", addr));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ use std::{fmt, net};
|
|||
use crate::http::{
|
||||
header, HeaderMap, HttpMessage, Method, Payload, RequestHead, Response, Uri, Version,
|
||||
};
|
||||
use crate::io::IoRef;
|
||||
use crate::router::{Path, Resource};
|
||||
use crate::util::Extensions;
|
||||
|
||||
|
@ -87,6 +88,12 @@ impl<Err> WebRequest<Err> {
|
|||
WebResponse::new(res.into(), self.req)
|
||||
}
|
||||
|
||||
/// Io reference for current connection
|
||||
#[inline]
|
||||
pub fn io(&self) -> Option<&IoRef> {
|
||||
self.head().io.as_ref()
|
||||
}
|
||||
|
||||
/// This method returns reference to the request head
|
||||
#[inline]
|
||||
pub fn head(&self) -> &RequestHead {
|
||||
|
@ -147,17 +154,6 @@ impl<Err> WebRequest<Err> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Peer socket address
|
||||
///
|
||||
/// Peer address is actual socket address, if proxy is used in front of
|
||||
/// ntex http server, then peer address would be address of this proxy.
|
||||
///
|
||||
/// To get client connection information `ConnectionInfo` should be used.
|
||||
#[inline]
|
||||
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
|
||||
self.head().peer_addr
|
||||
}
|
||||
|
||||
/// Get *ConnectionInfo* for the current request.
|
||||
#[inline]
|
||||
pub fn connection_info(&self) -> Ref<'_, ConnectionInfo> {
|
||||
|
|
|
@ -2,8 +2,8 @@ use std::{fmt, io, marker::PhantomData, net, sync::Arc, sync::Mutex};
|
|||
|
||||
#[cfg(feature = "openssl")]
|
||||
use crate::server::openssl::{AlpnError, SslAcceptor, SslAcceptorBuilder};
|
||||
#[cfg(feature = "rustls")]
|
||||
use crate::server::rustls::ServerConfig as RustlsServerConfig;
|
||||
//#[cfg(feature = "rustls")]
|
||||
//use crate::server::rustls::ServerConfig as RustlsServerConfig;
|
||||
|
||||
#[cfg(unix)]
|
||||
use crate::http::Protocol;
|
||||
|
@ -275,7 +275,6 @@ where
|
|||
.disconnect_timeout(c.client_disconnect)
|
||||
.memory_pool(c.pool)
|
||||
.finish(map_config(factory(), move |_| cfg.clone()))
|
||||
.tcp()
|
||||
},
|
||||
)?;
|
||||
Ok(self)
|
||||
|
@ -326,50 +325,50 @@ where
|
|||
Ok(self)
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
/// Use listener for accepting incoming tls connection requests
|
||||
///
|
||||
/// This method sets alpn protocols to "h2" and "http/1.1"
|
||||
pub fn listen_rustls(
|
||||
self,
|
||||
lst: net::TcpListener,
|
||||
config: RustlsServerConfig,
|
||||
) -> io::Result<Self> {
|
||||
self.listen_rustls_inner(lst, config)
|
||||
}
|
||||
// #[cfg(feature = "rustls")]
|
||||
// /// Use listener for accepting incoming tls connection requests
|
||||
// ///
|
||||
// /// This method sets alpn protocols to "h2" and "http/1.1"
|
||||
// pub fn listen_rustls(
|
||||
// self,
|
||||
// lst: net::TcpListener,
|
||||
// config: RustlsServerConfig,
|
||||
// ) -> io::Result<Self> {
|
||||
// self.listen_rustls_inner(lst, config)
|
||||
// }
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
fn listen_rustls_inner(
|
||||
mut self,
|
||||
lst: net::TcpListener,
|
||||
config: RustlsServerConfig,
|
||||
) -> io::Result<Self> {
|
||||
let factory = self.factory.clone();
|
||||
let cfg = self.config.clone();
|
||||
let addr = lst.local_addr().unwrap();
|
||||
// #[cfg(feature = "rustls")]
|
||||
// fn listen_rustls_inner(
|
||||
// mut self,
|
||||
// lst: net::TcpListener,
|
||||
// config: RustlsServerConfig,
|
||||
// ) -> io::Result<Self> {
|
||||
// let factory = self.factory.clone();
|
||||
// let cfg = self.config.clone();
|
||||
// let addr = lst.local_addr().unwrap();
|
||||
|
||||
self.builder = self.builder.listen(
|
||||
format!("ntex-web-rustls-service-{}", addr),
|
||||
lst,
|
||||
move || {
|
||||
let c = cfg.lock().unwrap();
|
||||
let cfg = AppConfig::new(
|
||||
true,
|
||||
addr,
|
||||
c.host.clone().unwrap_or_else(|| format!("{}", addr)),
|
||||
);
|
||||
HttpService::build()
|
||||
.keep_alive(c.keep_alive)
|
||||
.client_timeout(c.client_timeout)
|
||||
.disconnect_timeout(c.client_disconnect)
|
||||
.ssl_handshake_timeout(c.handshake_timeout)
|
||||
.memory_pool(c.pool)
|
||||
.finish(map_config(factory(), move |_| cfg.clone()))
|
||||
.rustls(config.clone())
|
||||
},
|
||||
)?;
|
||||
Ok(self)
|
||||
}
|
||||
// self.builder = self.builder.listen(
|
||||
// format!("ntex-web-rustls-service-{}", addr),
|
||||
// lst,
|
||||
// move || {
|
||||
// let c = cfg.lock().unwrap();
|
||||
// let cfg = AppConfig::new(
|
||||
// true,
|
||||
// addr,
|
||||
// c.host.clone().unwrap_or_else(|| format!("{}", addr)),
|
||||
// );
|
||||
// HttpService::build()
|
||||
// .keep_alive(c.keep_alive)
|
||||
// .client_timeout(c.client_timeout)
|
||||
// .disconnect_timeout(c.client_disconnect)
|
||||
// .ssl_handshake_timeout(c.handshake_timeout)
|
||||
// .memory_pool(c.pool)
|
||||
// .finish(map_config(factory(), move |_| cfg.clone()))
|
||||
// .rustls(config.clone())
|
||||
// },
|
||||
// )?;
|
||||
// Ok(self)
|
||||
// }
|
||||
|
||||
/// The socket address to bind
|
||||
///
|
||||
|
@ -437,21 +436,21 @@ where
|
|||
Ok(self)
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
/// Start listening for incoming tls connections.
|
||||
///
|
||||
/// This method sets alpn protocols to "h2" and "http/1.1"
|
||||
pub fn bind_rustls<A: net::ToSocketAddrs>(
|
||||
mut self,
|
||||
addr: A,
|
||||
config: RustlsServerConfig,
|
||||
) -> io::Result<Self> {
|
||||
let sockets = self.bind2(addr)?;
|
||||
for lst in sockets {
|
||||
self = self.listen_rustls_inner(lst, config.clone())?;
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
// #[cfg(feature = "rustls")]
|
||||
// /// Start listening for incoming tls connections.
|
||||
// ///
|
||||
// /// This method sets alpn protocols to "h2" and "http/1.1"
|
||||
// pub fn bind_rustls<A: net::ToSocketAddrs>(
|
||||
// mut self,
|
||||
// addr: A,
|
||||
// config: RustlsServerConfig,
|
||||
// ) -> io::Result<Self> {
|
||||
// let sockets = self.bind2(addr)?;
|
||||
// for lst in sockets {
|
||||
// self = self.listen_rustls_inner(lst, config.clone())?;
|
||||
// }
|
||||
// Ok(self)
|
||||
// }
|
||||
|
||||
#[cfg(unix)]
|
||||
/// Start listening for unix domain connections on existing listener.
|
||||
|
@ -479,16 +478,11 @@ where
|
|||
socket_addr,
|
||||
c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)),
|
||||
);
|
||||
pipeline_factory(|io: UnixStream| {
|
||||
crate::util::Ready::Ok((io, Protocol::Http1, None))
|
||||
})
|
||||
.and_then(
|
||||
HttpService::build()
|
||||
.keep_alive(c.keep_alive)
|
||||
.client_timeout(c.client_timeout)
|
||||
.memory_pool(c.pool)
|
||||
.finish(map_config(factory(), move |_| config.clone())),
|
||||
)
|
||||
HttpService::build()
|
||||
.keep_alive(c.keep_alive)
|
||||
.client_timeout(c.client_timeout)
|
||||
.memory_pool(c.pool)
|
||||
.finish(map_config(factory(), move |_| config.clone()))
|
||||
})?;
|
||||
Ok(self)
|
||||
}
|
||||
|
@ -520,16 +514,11 @@ where
|
|||
socket_addr,
|
||||
c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)),
|
||||
);
|
||||
pipeline_factory(|io: UnixStream| {
|
||||
crate::util::Ready::Ok((io, Protocol::Http1, None))
|
||||
})
|
||||
.and_then(
|
||||
HttpService::build()
|
||||
.keep_alive(c.keep_alive)
|
||||
.client_timeout(c.client_timeout)
|
||||
.memory_pool(c.pool)
|
||||
.finish(map_config(factory(), move |_| config.clone())),
|
||||
)
|
||||
HttpService::build()
|
||||
.keep_alive(c.keep_alive)
|
||||
.client_timeout(c.client_timeout)
|
||||
.memory_pool(c.pool)
|
||||
.finish(map_config(factory(), move |_| config.clone()))
|
||||
},
|
||||
)?;
|
||||
Ok(self)
|
||||
|
|
|
@ -465,15 +465,12 @@ impl TestRequest {
|
|||
|
||||
/// Complete request creation and generate `Request` instance
|
||||
pub fn to_request(mut self) -> Request {
|
||||
let mut req = self.req.finish();
|
||||
req.head_mut().peer_addr = self.peer_addr;
|
||||
req
|
||||
self.req.finish()
|
||||
}
|
||||
|
||||
/// Complete request creation and generate `WebRequest` instance
|
||||
pub fn to_srv_request(mut self) -> WebRequest<DefaultError> {
|
||||
let (mut head, payload) = self.req.finish().into_parts();
|
||||
head.peer_addr = self.peer_addr;
|
||||
let (head, payload) = self.req.finish().into_parts();
|
||||
*self.path.get_mut() = head.uri.clone();
|
||||
|
||||
WebRequest::new(HttpRequest::new(
|
||||
|
@ -494,8 +491,7 @@ impl TestRequest {
|
|||
|
||||
/// Complete request creation and generate `HttpRequest` instance
|
||||
pub fn to_http_request(mut self) -> HttpRequest {
|
||||
let (mut head, payload) = self.req.finish().into_parts();
|
||||
head.peer_addr = self.peer_addr;
|
||||
let (head, payload) = self.req.finish().into_parts();
|
||||
*self.path.get_mut() = head.uri.clone();
|
||||
|
||||
HttpRequest::new(
|
||||
|
@ -511,8 +507,7 @@ impl TestRequest {
|
|||
|
||||
/// Complete request creation and generate `HttpRequest` and `Payload` instances
|
||||
pub fn to_http_parts(mut self) -> (HttpRequest, Payload) {
|
||||
let (mut head, payload) = self.req.finish().into_parts();
|
||||
head.peer_addr = self.peer_addr;
|
||||
let (head, payload) = self.req.finish().into_parts();
|
||||
*self.path.get_mut() = head.uri.clone();
|
||||
|
||||
let req = HttpRequest::new(
|
||||
|
@ -636,7 +631,6 @@ where
|
|||
HttpService::build()
|
||||
.client_timeout(ctimeout)
|
||||
.h1(map_config(factory(), move |_| cfg.clone()))
|
||||
.tcp()
|
||||
}),
|
||||
HttpVer::Http2 => builder.listen("test", tcp, move || {
|
||||
let cfg =
|
||||
|
@ -644,7 +638,6 @@ where
|
|||
HttpService::build()
|
||||
.client_timeout(ctimeout)
|
||||
.h2(map_config(factory(), move |_| cfg.clone()))
|
||||
.tcp()
|
||||
}),
|
||||
HttpVer::Both => builder.listen("test", tcp, move || {
|
||||
let cfg =
|
||||
|
@ -652,7 +645,6 @@ where
|
|||
HttpService::build()
|
||||
.client_timeout(ctimeout)
|
||||
.finish(map_config(factory(), move |_| cfg.clone()))
|
||||
.tcp()
|
||||
}),
|
||||
},
|
||||
#[cfg(feature = "openssl")]
|
||||
|
@ -842,8 +834,9 @@ impl TestServerConfig {
|
|||
/// Start rustls server
|
||||
#[cfg(feature = "rustls")]
|
||||
pub fn rustls(mut self, config: rust_tls::ServerConfig) -> Self {
|
||||
self.stream = StreamType::Rustls(config);
|
||||
self
|
||||
// self.stream = StreamType::Rustls(config);
|
||||
// self
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
/// Set server client timeout in seconds for first request.
|
||||
|
@ -928,19 +921,12 @@ impl TestServer {
|
|||
}
|
||||
|
||||
/// Connect to websocket server at a given path
|
||||
pub async fn ws_at(
|
||||
&self,
|
||||
path: &str,
|
||||
) -> Result<ws::WsConnection<impl AsyncRead + AsyncWrite>, WsClientError> {
|
||||
let url = self.url(path);
|
||||
let connect = self.client.ws(url).connect();
|
||||
connect.await
|
||||
pub async fn ws_at(&self, path: &str) -> Result<ws::WsConnection, WsClientError> {
|
||||
self.client.ws(self.url(path)).connect().await
|
||||
}
|
||||
|
||||
/// Connect to a websocket server
|
||||
pub async fn ws(
|
||||
&self,
|
||||
) -> Result<ws::WsConnection<impl AsyncRead + AsyncWrite>, WsClientError> {
|
||||
pub async fn ws(&self) -> Result<ws::WsConnection, WsClientError> {
|
||||
self.ws_at("/").await
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ pub use self::stream::{StreamDecoder, StreamEncoder};
|
|||
pub enum WsError<E> {
|
||||
Service(E),
|
||||
KeepAlive,
|
||||
Disconnected,
|
||||
Protocol(ProtocolError),
|
||||
Io(io::Error),
|
||||
}
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
use std::{future::Future, rc::Rc};
|
||||
|
||||
use crate::framed::{OnDisconnect, State};
|
||||
use crate::io::{Io, IoRef, OnDisconnect};
|
||||
use crate::ws;
|
||||
|
||||
pub struct WsSink(Rc<WsSinkInner>);
|
||||
|
||||
struct WsSinkInner {
|
||||
state: State,
|
||||
io: IoRef,
|
||||
codec: ws::Codec,
|
||||
}
|
||||
|
||||
impl WsSink {
|
||||
pub(crate) fn new(state: State, codec: ws::Codec) -> Self {
|
||||
Self(Rc::new(WsSinkInner { state, codec }))
|
||||
pub(crate) fn new(io: IoRef, codec: ws::Codec) -> Self {
|
||||
Self(Rc::new(WsSinkInner { io, codec }))
|
||||
}
|
||||
|
||||
/// Endcode and send message to the peer.
|
||||
|
@ -23,13 +23,13 @@ impl WsSink {
|
|||
let inner = self.0.clone();
|
||||
|
||||
async move {
|
||||
inner.state.write().encode(item, &inner.codec)?;
|
||||
inner.io.write().encode(item, &inner.codec)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Notify when connection get disconnected
|
||||
pub fn on_disconnect(&self) -> OnDisconnect {
|
||||
self.0.state.on_disconnect()
|
||||
self.0.io.on_disconnect()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -215,15 +215,12 @@ async fn test_connection_reuse() {
|
|||
num2.fetch_add(1, Ordering::Relaxed);
|
||||
ok(io)
|
||||
})
|
||||
.and_then(
|
||||
HttpService::new(map_config(
|
||||
App::new().service(
|
||||
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
|
||||
),
|
||||
|_| AppConfig::default(),
|
||||
))
|
||||
.tcp(),
|
||||
)
|
||||
.and_then(HttpService::new(map_config(
|
||||
App::new().service(
|
||||
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
|
||||
),
|
||||
|_| AppConfig::default(),
|
||||
)))
|
||||
});
|
||||
|
||||
let client = Client::build().timeout(Seconds(10)).finish();
|
||||
|
@ -253,15 +250,12 @@ async fn test_connection_force_close() {
|
|||
num2.fetch_add(1, Ordering::Relaxed);
|
||||
ok(io)
|
||||
})
|
||||
.and_then(
|
||||
HttpService::new(map_config(
|
||||
App::new().service(
|
||||
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
|
||||
),
|
||||
|_| AppConfig::default(),
|
||||
))
|
||||
.tcp(),
|
||||
)
|
||||
.and_then(HttpService::new(map_config(
|
||||
App::new().service(
|
||||
web::resource("/").route(web::to(|| async { HttpResponse::Ok() })),
|
||||
),
|
||||
|_| AppConfig::default(),
|
||||
)))
|
||||
});
|
||||
|
||||
let client = Client::build().timeout(Seconds(10)).finish();
|
||||
|
@ -291,15 +285,12 @@ async fn test_connection_server_close() {
|
|||
num2.fetch_add(1, Ordering::Relaxed);
|
||||
ok(io)
|
||||
})
|
||||
.and_then(
|
||||
HttpService::new(map_config(
|
||||
App::new().service(web::resource("/").route(web::to(|| async {
|
||||
HttpResponse::Ok().force_close().finish()
|
||||
}))),
|
||||
|_| AppConfig::default(),
|
||||
))
|
||||
.tcp(),
|
||||
)
|
||||
.and_then(HttpService::new(map_config(
|
||||
App::new().service(web::resource("/").route(web::to(|| async {
|
||||
HttpResponse::Ok().force_close().finish()
|
||||
}))),
|
||||
|_| AppConfig::default(),
|
||||
)))
|
||||
});
|
||||
|
||||
let client = Client::build().timeout(Seconds(10)).finish();
|
||||
|
@ -329,16 +320,13 @@ async fn test_connection_wait_queue() {
|
|||
num2.fetch_add(1, Ordering::Relaxed);
|
||||
ok(io)
|
||||
})
|
||||
.and_then(
|
||||
HttpService::new(map_config(
|
||||
App::new().service(
|
||||
web::resource("/")
|
||||
.route(web::to(|| async { HttpResponse::Ok().body(STR) })),
|
||||
),
|
||||
|_| AppConfig::default(),
|
||||
))
|
||||
.tcp(),
|
||||
)
|
||||
.and_then(HttpService::new(map_config(
|
||||
App::new().service(
|
||||
web::resource("/")
|
||||
.route(web::to(|| async { HttpResponse::Ok().body(STR) })),
|
||||
),
|
||||
|_| AppConfig::default(),
|
||||
)))
|
||||
});
|
||||
|
||||
let client = Client::build()
|
||||
|
@ -378,15 +366,12 @@ async fn test_connection_wait_queue_force_close() {
|
|||
num2.fetch_add(1, Ordering::Relaxed);
|
||||
ok(io)
|
||||
})
|
||||
.and_then(
|
||||
HttpService::new(map_config(
|
||||
App::new().service(web::resource("/").route(web::to(|| async {
|
||||
HttpResponse::Ok().force_close().body(STR)
|
||||
}))),
|
||||
|_| AppConfig::default(),
|
||||
))
|
||||
.tcp(),
|
||||
)
|
||||
.and_then(HttpService::new(map_config(
|
||||
App::new().service(web::resource("/").route(web::to(|| async {
|
||||
HttpResponse::Ok().force_close().body(STR)
|
||||
}))),
|
||||
|_| AppConfig::default(),
|
||||
)))
|
||||
});
|
||||
|
||||
let client = Client::build()
|
||||
|
|
|
@ -8,8 +8,8 @@ use ntex::http::test::server as test_server;
|
|||
use ntex::http::{
|
||||
body, header, HttpService, KeepAlive, Method, Request, Response, StatusCode,
|
||||
};
|
||||
use ntex::time::{sleep, Millis};
|
||||
use ntex::{service::fn_service, time::Seconds, util::Bytes, web::error};
|
||||
use ntex::time::{sleep, Millis, Seconds};
|
||||
use ntex::{service::fn_service, util::Bytes, util::Ready, web::error};
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_h1() {
|
||||
|
@ -22,7 +22,6 @@ async fn test_h1() {
|
|||
assert!(req.peer_addr().is_some());
|
||||
future::ok::<_, io::Error>(Response::Ok().finish())
|
||||
})
|
||||
.tcp()
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -41,7 +40,6 @@ async fn test_h1_2() {
|
|||
assert_eq!(req.version(), http::Version::HTTP_11);
|
||||
future::ok::<_, io::Error>(Response::Ok().finish())
|
||||
})
|
||||
.tcp()
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -72,7 +70,6 @@ async fn test_expect_continue() {
|
|||
let _ = req.payload().next().await;
|
||||
Ok::<_, io::Error>(Response::Ok().finish())
|
||||
}))
|
||||
.tcp()
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -101,22 +98,18 @@ async fn test_chunked_payload() {
|
|||
let total_size: usize = chunk_sizes.iter().sum();
|
||||
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(fn_service(|mut request: Request| {
|
||||
request
|
||||
.take_payload()
|
||||
.map(|res| match res {
|
||||
Ok(pl) => pl,
|
||||
Err(e) => panic!("Error reading payload: {}", e),
|
||||
})
|
||||
.fold(0usize, |acc, chunk| ready(acc + chunk.len()))
|
||||
.map(|req_size| {
|
||||
Ok::<_, io::Error>(
|
||||
Response::Ok().body(format!("size={}", req_size)),
|
||||
)
|
||||
})
|
||||
}))
|
||||
.tcp()
|
||||
HttpService::build().h1(fn_service(|mut request: Request| {
|
||||
request
|
||||
.take_payload()
|
||||
.map(|res| match res {
|
||||
Ok(pl) => pl,
|
||||
Err(e) => panic!("Error reading payload: {}", e),
|
||||
})
|
||||
.fold(0usize, |acc, chunk| ready(acc + chunk.len()))
|
||||
.map(|req_size| {
|
||||
Ok::<_, io::Error>(Response::Ok().body(format!("size={}", req_size)))
|
||||
})
|
||||
}))
|
||||
});
|
||||
|
||||
let returned_size = {
|
||||
|
@ -156,7 +149,6 @@ async fn test_slow_request() {
|
|||
HttpService::build()
|
||||
.client_timeout(Seconds(1))
|
||||
.finish(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -169,9 +161,7 @@ async fn test_slow_request() {
|
|||
#[ntex::test]
|
||||
async fn test_http1_malformed_request() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -184,9 +174,7 @@ async fn test_http1_malformed_request() {
|
|||
#[ntex::test]
|
||||
async fn test_http1_keepalive() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -207,7 +195,6 @@ async fn test_http1_keepalive_timeout() {
|
|||
HttpService::build()
|
||||
.keep_alive(1)
|
||||
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -225,9 +212,7 @@ async fn test_http1_keepalive_timeout() {
|
|||
#[ntex::test]
|
||||
async fn test_http1_keepalive_close() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -245,9 +230,7 @@ async fn test_http1_keepalive_close() {
|
|||
#[ntex::test]
|
||||
async fn test_http10_keepalive_default_close() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -264,9 +247,7 @@ async fn test_http10_keepalive_default_close() {
|
|||
#[ntex::test]
|
||||
async fn test_http10_keepalive() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -293,7 +274,6 @@ async fn test_http1_keepalive_disabled() {
|
|||
HttpService::build()
|
||||
.keep_alive(KeepAlive::Disabled)
|
||||
.h1(|_| future::ok::<_, io::Error>(Response::Ok().finish()))
|
||||
.tcp()
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
|
@ -315,20 +295,18 @@ async fn test_content_length() {
|
|||
};
|
||||
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|req: Request| {
|
||||
let indx: usize = req.uri().path()[1..].parse().unwrap();
|
||||
let statuses = [
|
||||
StatusCode::NO_CONTENT,
|
||||
StatusCode::CONTINUE,
|
||||
StatusCode::SWITCHING_PROTOCOLS,
|
||||
StatusCode::PROCESSING,
|
||||
StatusCode::OK,
|
||||
StatusCode::NOT_FOUND,
|
||||
];
|
||||
future::ok::<_, io::Error>(Response::new(statuses[indx]))
|
||||
})
|
||||
.tcp()
|
||||
HttpService::build().h1(|req: Request| {
|
||||
let indx: usize = req.uri().path()[1..].parse().unwrap();
|
||||
let statuses = [
|
||||
StatusCode::NO_CONTENT,
|
||||
StatusCode::CONTINUE,
|
||||
StatusCode::SWITCHING_PROTOCOLS,
|
||||
StatusCode::PROCESSING,
|
||||
StatusCode::OK,
|
||||
StatusCode::NOT_FOUND,
|
||||
];
|
||||
future::ok::<_, io::Error>(Response::new(statuses[indx]))
|
||||
})
|
||||
});
|
||||
|
||||
let header = HeaderName::from_static("content-length");
|
||||
|
@ -362,7 +340,7 @@ async fn test_h1_headers() {
|
|||
let data = data.clone();
|
||||
HttpService::build().h1(move |_| {
|
||||
let mut builder = Response::Ok();
|
||||
for idx in 0..90 {
|
||||
for idx in 0..20 {
|
||||
builder.header(
|
||||
format!("X-TEST-{}", idx).as_str(),
|
||||
"TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST \
|
||||
|
@ -380,8 +358,9 @@ async fn test_h1_headers() {
|
|||
TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST TEST ",
|
||||
);
|
||||
}
|
||||
future::ok::<_, io::Error>(builder.body(data.clone()))
|
||||
}).tcp()
|
||||
println!("SENDING body");
|
||||
Ready::Ok::<_, io::Error>(builder.body(data.clone()))
|
||||
})
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -417,9 +396,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
|
|||
#[ntex::test]
|
||||
async fn test_h1_body() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -433,9 +410,7 @@ async fn test_h1_body() {
|
|||
#[ntex::test]
|
||||
async fn test_h1_head_empty() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
|
||||
});
|
||||
|
||||
let response = srv.request(http::Method::HEAD, "/").send().await.unwrap();
|
||||
|
@ -457,13 +432,9 @@ async fn test_h1_head_empty() {
|
|||
#[ntex::test]
|
||||
async fn test_h1_head_binary() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| {
|
||||
ok::<_, io::Error>(
|
||||
Response::Ok().content_length(STR.len() as u64).body(STR),
|
||||
)
|
||||
})
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| {
|
||||
ok::<_, io::Error>(Response::Ok().content_length(STR.len() as u64).body(STR))
|
||||
})
|
||||
});
|
||||
|
||||
let response = srv.request(http::Method::HEAD, "/").send().await.unwrap();
|
||||
|
@ -485,9 +456,7 @@ async fn test_h1_head_binary() {
|
|||
#[ntex::test]
|
||||
async fn test_h1_head_binary2() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| ok::<_, io::Error>(Response::Ok().body(STR)))
|
||||
});
|
||||
|
||||
let response = srv.request(http::Method::HEAD, "/").send().await.unwrap();
|
||||
|
@ -505,14 +474,12 @@ async fn test_h1_head_binary2() {
|
|||
#[ntex::test]
|
||||
async fn test_h1_body_length() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| {
|
||||
let body = once(ok(Bytes::from_static(STR.as_ref())));
|
||||
ok::<_, io::Error>(
|
||||
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
|
||||
)
|
||||
})
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| {
|
||||
let body = once(ok(Bytes::from_static(STR.as_ref())));
|
||||
ok::<_, io::Error>(
|
||||
Response::Ok().body(body::SizedStream::new(STR.len() as u64, body)),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -525,18 +492,15 @@ async fn test_h1_body_length() {
|
|||
|
||||
#[ntex::test]
|
||||
async fn test_h1_body_chunked_explicit() {
|
||||
env_logger::init();
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| {
|
||||
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
|
||||
ok::<_, io::Error>(
|
||||
Response::Ok()
|
||||
.header(header::TRANSFER_ENCODING, "chunked")
|
||||
.streaming(body),
|
||||
)
|
||||
})
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| {
|
||||
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
|
||||
ok::<_, io::Error>(
|
||||
Response::Ok()
|
||||
.header(header::TRANSFER_ENCODING, "chunked")
|
||||
.streaming(body),
|
||||
)
|
||||
})
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -561,12 +525,10 @@ async fn test_h1_body_chunked_explicit() {
|
|||
#[ntex::test]
|
||||
async fn test_h1_body_chunked_implicit() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| {
|
||||
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
|
||||
ok::<_, io::Error>(Response::Ok().streaming(body))
|
||||
})
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| {
|
||||
let body = once(ok::<_, io::Error>(Bytes::from_static(STR.as_ref())));
|
||||
ok::<_, io::Error>(Response::Ok().streaming(body))
|
||||
})
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -589,16 +551,14 @@ async fn test_h1_body_chunked_implicit() {
|
|||
#[ntex::test]
|
||||
async fn test_h1_response_http_error_handling() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(fn_service(|_| {
|
||||
let broken_header = Bytes::from_static(b"\0\0\0");
|
||||
ok::<_, io::Error>(
|
||||
Response::Ok()
|
||||
.header(http::header::CONTENT_TYPE, &broken_header[..])
|
||||
.body(STR),
|
||||
)
|
||||
}))
|
||||
.tcp()
|
||||
HttpService::build().h1(fn_service(|_| {
|
||||
let broken_header = Bytes::from_static(b"\0\0\0");
|
||||
ok::<_, io::Error>(
|
||||
Response::Ok()
|
||||
.header(http::header::CONTENT_TYPE, &broken_header[..])
|
||||
.body(STR),
|
||||
)
|
||||
}))
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -612,14 +572,12 @@ async fn test_h1_response_http_error_handling() {
|
|||
#[ntex::test]
|
||||
async fn test_h1_service_error() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.h1(|_| {
|
||||
future::err::<Response, _>(error::InternalError::default(
|
||||
"error",
|
||||
StatusCode::BAD_REQUEST,
|
||||
))
|
||||
})
|
||||
.tcp()
|
||||
HttpService::build().h1(|_| {
|
||||
future::err::<Response, _>(error::InternalError::default(
|
||||
"error",
|
||||
StatusCode::BAD_REQUEST,
|
||||
))
|
||||
})
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
|
@ -629,19 +587,3 @@ async fn test_h1_service_error() {
|
|||
let bytes = srv.load_body(response).await.unwrap();
|
||||
assert_eq!(bytes, Bytes::from_static(b"error"));
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_h1_on_connect() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.on_connect(|_| 10usize)
|
||||
.h1(|req: Request| {
|
||||
assert!(req.extensions().contains::<usize>());
|
||||
future::ok::<_, io::Error>(Response::Ok().finish())
|
||||
})
|
||||
.tcp()
|
||||
});
|
||||
|
||||
let response = srv.request(Method::GET, "/").send().await.unwrap();
|
||||
assert!(response.status().is_success());
|
||||
}
|
||||
|
|
|
@ -3,10 +3,9 @@ use std::sync::{mpsc, Arc};
|
|||
use std::{io, io::Read, net, thread, time};
|
||||
|
||||
use futures::future::{lazy, ok, FutureExt};
|
||||
use futures::SinkExt;
|
||||
|
||||
use ntex::codec::{BytesCodec, Framed};
|
||||
use ntex::rt::net::TcpStream;
|
||||
use ntex::codec::BytesCodec;
|
||||
use ntex::io::Io;
|
||||
use ntex::server::{Server, TestServer};
|
||||
use ntex::service::fn_service;
|
||||
use ntex::util::{Bytes, Ready};
|
||||
|
@ -77,9 +76,10 @@ fn test_start() {
|
|||
.backlog(100)
|
||||
.disable_signals()
|
||||
.bind("test", addr, move || {
|
||||
fn_service(|io: TcpStream| async move {
|
||||
let mut f = Framed::new(io, BytesCodec);
|
||||
f.send(Bytes::from_static(b"test")).await.unwrap();
|
||||
fn_service(|io: Io| async move {
|
||||
io.send(Bytes::from_static(b"test"), &BytesCodec)
|
||||
.await
|
||||
.unwrap();
|
||||
Ok::<_, ()>(())
|
||||
})
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue