restore h2 impl

This commit is contained in:
Nikolay Kim 2021-12-15 21:32:24 +06:00
parent 3dbba47ab1
commit 399b238621
11 changed files with 180 additions and 147 deletions

View file

@ -425,7 +425,6 @@ impl IoRef {
let is_write_sleep = buf.is_empty();
codec.encode(item, &mut buf).map_err(Either::Left)?;
filter.release_write_buf(buf).map_err(Either::Right)?;
self.0.insert_flags(Flags::WR_WAIT);
if is_write_sleep {
self.0.write_task.wake();
}
@ -681,7 +680,7 @@ impl<'a> WriteRef<'a> {
}
#[inline]
/// Write item to a buffer and wake up write task
/// Encode and write item to a buffer and wake up write task
///
/// Returns write buffer state, false is returned if write buffer if full.
pub fn encode<U>(
@ -724,6 +723,36 @@ impl<'a> WriteRef<'a> {
}
}
#[inline]
/// Write item to a buffer and wake up write task
///
/// Returns write buffer state, false is returned if write buffer if full.
pub fn write(&self, src: &[u8]) -> Result<bool, io::Error> {
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
let is_write_sleep = buf.is_empty();
// write and wake write task
buf.extend_from_slice(src);
let result = buf.len() < self.0.pool.get().write_params_high();
if is_write_sleep {
self.0.write_task.wake();
}
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
}
Ok(result)
} else {
Ok(true)
}
}
#[inline]
/// Wake write task and instruct to write data.
///

View file

@ -1,11 +1,12 @@
use std::task::{Context, Poll};
use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc};
use std::{cell::RefCell, cmp, future::Future, io, pin::Pin, rc::Rc};
use ntex_bytes::{Buf, BufMut};
use ntex_util::time::{sleep, Sleep};
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf, net::TcpStream};
use tok_io::io::{AsyncRead, AsyncWrite, ReadBuf};
use tok_io::net::TcpStream;
use super::{IoStream, ReadContext, WriteContext, WriteReadiness};
use super::{Filter, Io, IoStream, ReadContext, WriteContext, WriteReadiness};
impl IoStream for TcpStream {
fn start(self, read: ReadContext, write: WriteContext) {
@ -340,3 +341,50 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
Poll::Ready(true)
}
}
impl<F: Filter> AsyncRead for Io<F> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let read = self.read();
let len = read.with_buf(|src| {
let len = cmp::min(src.len(), buf.capacity());
buf.put_slice(&src.split_to(len));
len
});
if len == 0 && !self.0.is_io_open() {
if let Some(err) = self.0.take_error() {
return Poll::Ready(Err(err));
}
}
if read.poll_ready(cx)?.is_ready() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
impl<F: Filter> AsyncWrite for Io<F> {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.write().write(buf).map(|_| buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.write().poll_flush(cx, false)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
self.0.poll_shutdown(cx)
}
}

View file

@ -4,12 +4,12 @@ use crate::http::body::MessageBody;
use crate::http::config::{KeepAlive, OnRequest, ServiceConfig};
use crate::http::error::ResponseError;
use crate::http::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler};
use crate::io::{Filter, Io, IoRef};
// use crate::http::h2::H2Service;
use crate::http::h2::H2Service;
use crate::http::helpers::{Data, DataFactory};
use crate::http::request::Request;
use crate::http::response::Response;
use crate::http::service::HttpService;
use crate::io::{Filter, Io, IoRef};
use crate::service::{boxed, IntoService, IntoServiceFactory, Service, ServiceFactory};
use crate::time::{Millis, Seconds};
use crate::util::PoolId;
@ -214,9 +214,8 @@ where
.on_request(self.on_request)
}
// pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
/// Finish service configuration and create *http service* for HTTP/2 protocol.
pub fn h2<B, SF>(self, service: SF) -> H1Service<F, S, B, X, U>
pub fn h2<B, SF>(self, service: SF) -> H2Service<F, S, B>
where
B: MessageBody + 'static,
SF: IntoServiceFactory<S>,
@ -233,11 +232,7 @@ where
self.pool,
);
// H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
H1Service::with_config(cfg, service.into_factory())
.expect(self.expect)
.upgrade(self.upgrade)
.on_request(self.on_request)
H2Service::with_config(cfg, service.into_factory())
}
/// Finish service configuration and create `HttpService` instance.

View file

@ -310,7 +310,7 @@ impl WsRequest {
let fut = self.config.connector.open_tunnel(head.into(), self.addr);
// set request timeout
let (head, io, codec) = if self.config.timeout.non_zero() {
let (head, io, _) = if self.config.timeout.non_zero() {
timeout(self.config.timeout, fut)
.await
.map_err(|_| SendRequestError::Timeout)

View file

@ -17,7 +17,8 @@ use crate::http::message::ResponseHead;
use crate::http::payload::Payload;
use crate::http::request::Request;
use crate::http::response::Response;
use crate::time::Sleep;
use crate::io::{Filter, Io, IoRef};
use crate::time::{now, Sleep};
use crate::util::{Bytes, BytesMut};
use crate::Service;
@ -25,31 +26,29 @@ const CHUNK_SIZE: usize = 16_384;
pin_project_lite::pin_project! {
/// Dispatcher for HTTP/2 protocol
pub struct Dispatcher<T, S: Service<Request = Request>, B: MessageBody, X, U> {
config: Rc<DispatcherConfig<T, S, X, U>>,
connection: Connection<T, Bytes>,
on_connect: Option<Box<dyn DataFactory>>,
peer_addr: Option<net::SocketAddr>,
pub struct Dispatcher<F, S: Service<Request = Request>, B: MessageBody, X, U> {
io: IoRef,
config: Rc<DispatcherConfig<S, X, U>>,
connection: Connection<Io<F>, Bytes>,
ka_expire: time::Instant,
ka_timer: Option<Sleep>,
_t: PhantomData<B>,
}
}
impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
impl<F, S, B, X, U> Dispatcher<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
F: Filter,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>>,
B: MessageBody,
{
pub(in crate::http) fn new(
config: Rc<DispatcherConfig<T, S, X, U>>,
connection: Connection<T, Bytes>,
on_connect: Option<Box<dyn DataFactory>>,
io: IoRef,
config: Rc<DispatcherConfig<S, X, U>>,
connection: Connection<Io<F>, Bytes>,
timeout: Option<Sleep>,
peer_addr: Option<net::SocketAddr>,
) -> Self {
// keep-alive timer
let (ka_expire, ka_timer) = if let Some(delay) = timeout {
@ -61,14 +60,13 @@ where
config.timer.now() + std::time::Duration::from(config.keep_alive);
(expire, Some(delay))
} else {
(config.now(), None)
(now(), None)
};
Dispatcher {
io,
config,
peer_addr,
connection,
on_connect,
ka_expire,
ka_timer,
_t: PhantomData,
@ -76,9 +74,9 @@ where
}
}
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
impl<F, S, B, X, U> Future for Dispatcher<F, S, B, X, U>
where
T: AsyncRead + AsyncWrite + Unpin,
F: Filter,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Future: 'static,
@ -115,12 +113,7 @@ where
head.method = parts.method;
head.version = parts.version;
head.headers = parts.headers.into();
head.peer_addr = this.peer_addr;
// set on_connect data
if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut());
}
head.io = Some(this.io.clone());
crate::rt::spawn(ServiceResponse {
state: ServiceResponseState::ServiceCall {

View file

@ -4,11 +4,11 @@ use std::task::{Context, Poll};
use h2::RecvStream;
//mod dispatcher;
//mod service;
mod dispatcher;
mod service;
//pub use self::dispatcher::Dispatcher;
//pub use self::service::H2Service;
pub use self::dispatcher::Dispatcher;
pub use self::service::H2Service;
use crate::{http::error::PayloadError, util::Bytes, Stream};
/// H2 receive stream

View file

@ -4,13 +4,13 @@ use std::{future::Future, marker::PhantomData, net, pin::Pin, rc::Rc};
use h2::server::{self, Handshake};
use log::error;
use crate::codec::{AsyncRead, AsyncWrite};
use crate::http::body::MessageBody;
use crate::http::config::{DispatcherConfig, ServiceConfig};
use crate::http::error::{DispatchError, ResponseError};
use crate::http::helpers::DataFactory;
use crate::http::request::Request;
use crate::http::response::Response;
use crate::io::{Filter, Io, IoRef};
use crate::rt::net::TcpStream;
use crate::service::{
fn_factory, fn_service, pipeline_factory, IntoServiceFactory, Service,
@ -22,16 +22,15 @@ use crate::util::Bytes;
use super::dispatcher::Dispatcher;
/// `ServiceFactory` implementation for HTTP2 transport
pub struct H2Service<T, S, B> {
pub struct H2Service<F, S, B> {
srv: S,
cfg: ServiceConfig,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
#[allow(dead_code)]
handshake_timeout: Millis,
_t: PhantomData<(T, B)>,
_t: PhantomData<(F, B)>,
}
impl<T, S, B> H2Service<T, S, B>
impl<F, S, B> H2Service<F, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
@ -41,56 +40,17 @@ where
B: MessageBody + 'static,
{
/// Create new `HttpService` instance with config.
pub(crate) fn with_config<F: IntoServiceFactory<S>>(
pub(crate) fn with_config<U: IntoServiceFactory<S>>(
cfg: ServiceConfig,
service: F,
service: U,
) -> Self {
H2Service {
on_connect: None,
srv: service.into_factory(),
handshake_timeout: cfg.0.ssl_handshake_timeout,
_t: PhantomData,
cfg,
}
}
/// Set on connect callback.
pub(crate) fn on_connect(
mut self,
f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> Self {
self.on_connect = f;
self
}
}
impl<S, B> H2Service<TcpStream, S, B>
where
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>> + 'static,
S::Future: 'static,
<S::Service as Service>::Future: 'static,
B: MessageBody + 'static,
{
/// Create simple tcp based service
pub fn tcp(
self,
) -> impl ServiceFactory<
Config = (),
Request = TcpStream,
Response = (),
Error = DispatchError,
InitError = S::InitError,
> {
pipeline_factory(fn_factory(|| async {
Ok::<_, S::InitError>(fn_service(|io: TcpStream| async move {
let peer_addr = io.peer_addr().ok();
Ok::<_, DispatchError>((io, peer_addr))
}))
}))
.and_then(self)
}
}
#[cfg(feature = "openssl")]
@ -188,9 +148,9 @@ mod rustls {
}
}
impl<T, S, B> ServiceFactory for H2Service<T, S, B>
impl<F, S, B> ServiceFactory for H2Service<F, S, B>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
F: Filter,
S: ServiceFactory<Config = (), Request = Request>,
S::Error: ResponseError + 'static,
S::Response: Into<Response<B>> + 'static,
@ -199,17 +159,16 @@ where
B: MessageBody + 'static,
{
type Config = ();
type Request = (T, Option<net::SocketAddr>);
type Request = Io<F>;
type Response = ();
type Error = DispatchError;
type InitError = S::InitError;
type Service = H2ServiceHandler<T, S::Service, B>;
type Service = H2ServiceHandler<F, S::Service, B>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Self::InitError>>>>;
fn new_service(&self, _: ()) -> Self::Future {
let fut = self.srv.new_service(());
let cfg = self.cfg.clone();
let on_connect = self.on_connect.clone();
Box::pin(async move {
let service = fut.await?;
@ -217,7 +176,6 @@ where
Ok(H2ServiceHandler {
config,
on_connect,
_t: PhantomData,
})
})
@ -225,25 +183,24 @@ where
}
/// `Service` implementation for http/2 transport
pub struct H2ServiceHandler<T, S: Service, B> {
config: Rc<DispatcherConfig<T, S, (), ()>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
_t: PhantomData<(T, B)>,
pub struct H2ServiceHandler<F, S: Service, B> {
config: Rc<DispatcherConfig<S, (), ()>>,
_t: PhantomData<(F, B)>,
}
impl<T, S, B> Service for H2ServiceHandler<T, S, B>
impl<F, S, B> Service for H2ServiceHandler<F, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
F: Filter,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
{
type Request = (T, Option<net::SocketAddr>);
type Request = Io<F>;
type Response = ();
type Error = DispatchError;
type Future = H2ServiceHandlerResponse<T, S, B>;
type Future = H2ServiceHandlerResponse<F, S, B>;
#[inline]
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -258,49 +215,47 @@ where
self.config.service.poll_shutdown(cx, is_error)
}
fn call(&self, (io, addr): Self::Request) -> Self::Future {
trace!("New http2 connection, peer address: {:?}", addr);
fn call(&self, io: Self::Request) -> Self::Future {
// trace!("New http2 connection, peer address: {:?}", addr);
H2ServiceHandlerResponse {
state: State::Handshake(
io.get_ref(),
self.config.clone(),
addr,
self.on_connect.as_ref().map(|f| f(&io)),
server::Builder::new().handshake(io),
),
}
}
}
enum State<T, S: Service<Request = Request>, B: MessageBody>
enum State<F, S: Service<Request = Request>, B: MessageBody>
where
T: AsyncRead + AsyncWrite + Unpin,
F: Filter,
S::Future: 'static,
{
Incoming(Dispatcher<T, S, B, (), ()>),
Incoming(Dispatcher<F, S, B, (), ()>),
Handshake(
Rc<DispatcherConfig<T, S, (), ()>>,
Option<net::SocketAddr>,
Option<Box<dyn DataFactory>>,
Handshake<T, Bytes>,
IoRef,
Rc<DispatcherConfig<S, (), ()>>,
Handshake<Io<F>, Bytes>,
),
}
pub struct H2ServiceHandlerResponse<T, S, B>
pub struct H2ServiceHandlerResponse<F, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
F: Filter,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Future: 'static,
S::Response: Into<Response<B>> + 'static,
B: MessageBody + 'static,
{
state: State<T, S, B>,
state: State<F, S, B>,
}
impl<T, S, B> Future for H2ServiceHandlerResponse<T, S, B>
impl<F, S, B> Future for H2ServiceHandlerResponse<F, S, B>
where
T: AsyncRead + AsyncWrite + Unpin,
F: Filter,
S: Service<Request = Request>,
S::Error: ResponseError + 'static,
S::Future: 'static,
@ -312,29 +267,25 @@ where
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state {
State::Incoming(ref mut disp) => Pin::new(disp).poll(cx),
State::Handshake(
ref config,
peer_addr,
ref mut on_connect,
ref mut handshake,
) => match Pin::new(handshake).poll(cx) {
Poll::Ready(Ok(conn)) => {
trace!("H2 handshake completed");
self.state = State::Incoming(Dispatcher::new(
config.clone(),
conn,
on_connect.take(),
None,
peer_addr,
));
self.poll(cx)
State::Handshake(ref io, ref config, ref mut handshake) => {
match Pin::new(handshake).poll(cx) {
Poll::Ready(Ok(conn)) => {
trace!("H2 handshake completed");
self.state = State::Incoming(Dispatcher::new(
io.clone(),
config.clone(),
conn,
None,
));
self.poll(cx)
}
Poll::Ready(Err(err)) => {
trace!("H2 handshake error: {}", err);
Poll::Ready(Err(err.into()))
}
Poll::Pending => Poll::Pending,
}
Poll::Ready(Err(err)) => {
trace!("H2 handshake error: {}", err);
Poll::Ready(Err(err.into()))
}
Poll::Pending => Poll::Pending,
},
}
}
}
}

View file

@ -167,6 +167,17 @@ impl RequestHead {
pub(crate) fn set_expect(&mut self) {
self.flags.insert(Flags::EXPECT);
}
/// Peer socket address
///
/// Peer address is actual socket address, if proxy is used in front of
/// ntex http server, then peer address would be address of this proxy.
#[inline]
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
// TODO! fix
// self.head().peer_addr
None
}
}
#[derive(Debug)]

View file

@ -112,6 +112,17 @@ impl HttpRequest {
self.head().io.as_ref()
}
/// Peer socket address
///
/// Peer address is actual socket address, if proxy is used in front of
/// ntex http server, then peer address would be address of this proxy.
#[inline]
pub fn peer_addr(&self) -> Option<net::SocketAddr> {
// TODO! fix
// self.head().peer_addr
None
}
/// Get a reference to the Path parameters.
///
/// Params is a container for url parameters.

View file

@ -119,9 +119,7 @@ impl ConnectionInfo {
}
if remote.is_none() {
// get peeraddr from socketaddr
// TODO! fix
// peer = req.peer_addr.map(|addr| format!("{}", addr));
peer = req.peer_addr().map(|addr| format!("{}", addr));
}
}

View file

@ -965,10 +965,7 @@ mod tests {
.to_http_request();
assert!(req.headers().contains_key(header::CONTENT_TYPE));
assert!(req.headers().contains_key(header::DATE));
assert_eq!(
req.head().peer_addr,
Some("127.0.0.1:8081".parse().unwrap())
);
assert_eq!(req.peer_addr(), Some("127.0.0.1:8081".parse().unwrap()));
assert_eq!(&req.match_info()["test"], "123");
assert_eq!(req.version(), Version::HTTP_2);
let data = req.app_data::<web::types::Data<u64>>().unwrap();