Migrate to tokio 1.0 (#41)

* migrate to tokio 1.x

* update tests
This commit is contained in:
Nikolay Kim 2021-02-24 00:12:44 +06:00 committed by GitHub
parent ddd973b808
commit e04ae7cc86
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 557 additions and 807 deletions

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.4.0] - 2021-02-23
* Migrate to tokio 1.x
## [0.3.0] - 2021-02-20 ## [0.3.0] - 2021-02-20
* Make Encoder and Decoder methods immutable * Make Encoder and Decoder methods immutable

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-codec" name = "ntex-codec"
version = "0.3.0" version = "0.4.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames" description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -17,13 +17,13 @@ path = "src/lib.rs"
[dependencies] [dependencies]
bitflags = "1.2.1" bitflags = "1.2.1"
bytes = "0.5.6" bytes = "1.0"
either = "1.6.1" either = "1.6.1"
futures-core = "0.3.12" futures-core = "0.3.12"
futures-sink = "0.3.12" futures-sink = "0.3.12"
log = "0.4" log = "0.4"
tokio = { version = "0.2.6", default-features=false } tokio = { version = "1", default-features=false }
[dev-dependencies] [dev-dependencies]
ntex = "0.2.0-b.13" ntex = "0.3.0-b.1"
futures = "0.3.12" futures = "0.3.13"

View file

@ -1,4 +1,4 @@
use bytes::{Buf, Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use std::io; use std::io;
use super::{Decoder, Encoder}; use super::{Decoder, Encoder};
@ -15,7 +15,7 @@ impl Encoder for BytesCodec {
#[inline] #[inline]
fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.extend_from_slice(item.bytes()); dst.extend_from_slice(&item[..]);
Ok(()) Ok(())
} }
} }

View file

@ -306,9 +306,10 @@ where
// read until 0 or err // read until 0 or err
let mut buf = [0u8; 512]; let mut buf = [0u8; 512];
let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
loop { loop {
match ready!(Pin::new(&mut self.io).poll_read(cx, &mut buf)) { match ready!(Pin::new(&mut self.io).poll_read(cx, &mut read_buf)) {
Err(_) | Ok(0) => { Err(_) | Ok(_) if read_buf.filled().is_empty() => {
break; break;
} }
_ => (), _ => (),
@ -387,7 +388,11 @@ where
if remaining < LW { if remaining < LW {
self.read_buf.reserve(HW - remaining) self.read_buf.reserve(HW - remaining)
} }
match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) { match crate::poll_read_buf(
Pin::new(&mut self.io),
cx,
&mut self.read_buf,
) {
Poll::Pending => { Poll::Pending => {
if updated { if updated {
done_read = true; done_read = true;

View file

@ -7,6 +7,7 @@
//! [`AsyncRead`]: # //! [`AsyncRead`]: #
//! [`AsyncWrite`]: # //! [`AsyncWrite`]: #
#![deny(rust_2018_idioms, warnings)] #![deny(rust_2018_idioms, warnings)]
use std::{io, mem::MaybeUninit, pin::Pin, task::Context, task::Poll};
mod bcodec; mod bcodec;
mod decoder; mod decoder;
@ -18,4 +19,38 @@ pub use self::decoder::Decoder;
pub use self::encoder::Encoder; pub use self::encoder::Encoder;
pub use self::framed::{Framed, FramedParts}; pub use self::framed::{Framed, FramedParts};
pub use tokio::io::{AsyncRead, AsyncWrite}; pub use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use bytes::BufMut;
pub fn poll_read_buf<T: AsyncRead>(
io: Pin<&mut T>,
cx: &mut Context<'_>,
buf: &mut bytes::BytesMut,
) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let n = {
let dst = buf.chunk_mut();
let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
if io.poll_read(cx, &mut buf)?.is_pending() {
return Poll::Pending;
}
// Ensure the pointer does not change from under us
assert_eq!(ptr, buf.filled().as_ptr());
buf.filled().len()
};
// Safety: This is guaranteed to be the number of initialized (and read)
// bytes due to the invariants provided by `ReadBuf::filled`.
unsafe {
buf.advance_mut(n);
}
Poll::Ready(Ok(n))
}

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-macros" name = "ntex-macros"
version = "0.1.0" version = "0.1.1"
description = "ntex proc macros" description = "ntex proc macros"
readme = "README.md" readme = "README.md"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.4.0] - 2021-02-23
* Upgrade to bytestring 1.0
## [0.3.8] - 2020-10-28 ## [0.3.8] - 2020-10-28
* Router struct implements Clone trait * Router struct implements Clone trait

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-router" name = "ntex-router"
version = "0.3.8" version = "0.4.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Path router" description = "Path router"
keywords = ["ntex"] keywords = ["ntex"]
@ -19,7 +19,7 @@ default = ["http"]
[dependencies] [dependencies]
regex = "1.4.1" regex = "1.4.1"
serde = "1.0.116" serde = "1.0.116"
bytestring = "0.1.5" bytestring = "1.0"
log = "0.4.8" log = "0.4.8"
http = { version = "0.2.1", optional = true } http = { version = "0.2.1", optional = true }

View file

@ -17,4 +17,4 @@ quote = "^1"
syn = { version = "^1", features = ["full"] } syn = { version = "^1", features = ["full"] }
[dev-dependencies] [dev-dependencies]
ntex = { version = "0.1.0" } ntex = { version = "0.3.0-b.1" }

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.2.0] - 2021-02-23
* Migrate to tokio 1.x
## [0.1.2] - 2021-01-25 ## [0.1.2] - 2021-01-25
* Replace actix-threadpool with tokio's task utils * Replace actix-threadpool with tokio's task utils

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-rt" name = "ntex-rt"
version = "0.1.2" version = "0.2.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "ntex runtime" description = "ntex runtime"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -17,5 +17,5 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ntex-rt-macros = "0.1.0" ntex-rt-macros = "0.1.0"
futures = "0.3.12" futures = "0.3.13"
tokio = { version = "0.2.6", default-features=false, features = ["rt-core", "rt-util", "io-driver", "blocking", "tcp", "uds", "udp", "time", "signal", "stream"] } tokio = { version = "1", default-features=false, features = ["rt", "net", "time", "signal"] }

View file

@ -93,7 +93,7 @@ impl Arbiter {
let handle = thread::Builder::new() let handle = thread::Builder::new()
.name(name.clone()) .name(name.clone())
.spawn(move || { .spawn(move || {
let mut rt = Runtime::new().expect("Can not create Runtime"); let rt = Runtime::new().expect("Can not create Runtime");
let arb = Arbiter::with_sender(arb_tx); let arb = Arbiter::with_sender(arb_tx);
let (stop, stop_rx) = channel(); let (stop, stop_rx) = channel();
@ -132,7 +132,7 @@ impl Arbiter {
} }
/// Send a future to the Arbiter's thread, and spawn it. /// Send a future to the Arbiter's thread, and spawn it.
pub fn send<F>(&self, future: F) pub fn spawn<F>(&self, future: F)
where where
F: Future<Output = ()> + Send + Unpin + 'static, F: Future<Output = ()> + Send + Unpin + 'static,
{ {

View file

@ -98,7 +98,7 @@ impl Builder {
let (stop_tx, stop) = channel(); let (stop_tx, stop) = channel();
let (sys_sender, sys_receiver) = unbounded(); let (sys_sender, sys_receiver) = unbounded();
let mut rt = Runtime::new().unwrap(); let rt = Runtime::new().unwrap();
// system arbiter // system arbiter
let system = System::construct( let system = System::construct(
@ -161,7 +161,7 @@ impl SystemRunner {
/// This function will start event loop and will finish once the /// This function will start event loop and will finish once the
/// `System::stop()` function is called. /// `System::stop()` function is called.
pub fn run(self) -> io::Result<()> { pub fn run(self) -> io::Result<()> {
let SystemRunner { mut rt, stop, .. } = self; let SystemRunner { rt, stop, .. } = self;
// run loop // run loop
match rt.block_on(stop) { match rt.block_on(stop) {
@ -210,7 +210,9 @@ mod tests {
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
thread::spawn(move || { thread::spawn(move || {
let mut rt = tokio::runtime::Runtime::new().unwrap(); let rt = tokio::runtime::Builder::new_current_thread()
.build()
.unwrap();
let local = tokio::task::LocalSet::new(); let local = tokio::task::LocalSet::new();
let runner = crate::System::build() let runner = crate::System::build()
@ -237,7 +239,7 @@ mod tests {
assert_eq!(id, id2); assert_eq!(id, id2);
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
sys.arbiter().send(Box::pin(async move { sys.arbiter().spawn(Box::pin(async move {
let _ = tx.send(System::current().id()); let _ = tx.send(System::current().id());
})); }));
let id2 = rx.recv().unwrap(); let id2 = rx.recv().unwrap();

View file

@ -71,8 +71,11 @@ pub mod net {
/// Utilities for tracking time. /// Utilities for tracking time.
pub mod time { pub mod time {
pub use tokio::time::Instant; pub use tokio::time::Instant;
pub use tokio::time::{delay_for, delay_until, Delay};
pub use tokio::time::{interval, interval_at, Interval}; pub use tokio::time::{interval, interval_at, Interval};
pub use tokio::time::{sleep, sleep_until, Sleep};
pub use tokio::time::{
sleep as delay_for, sleep_until as delay_until, Sleep as Delay,
};
pub use tokio::time::{timeout, Timeout}; pub use tokio::time::{timeout, Timeout};
} }

View file

@ -18,10 +18,9 @@ impl Runtime {
#[allow(clippy::new_ret_no_self)] #[allow(clippy::new_ret_no_self)]
/// Returns a new runtime initialized with default configuration values. /// Returns a new runtime initialized with default configuration values.
pub fn new() -> io::Result<Runtime> { pub fn new() -> io::Result<Runtime> {
let rt = runtime::Builder::new() let rt = runtime::Builder::new_current_thread()
.enable_io() .enable_io()
.enable_time() .enable_time()
.basic_scheduler()
.build()?; .build()?;
Ok(Runtime { Ok(Runtime {
@ -86,10 +85,10 @@ impl Runtime {
/// ///
/// The caller is responsible for ensuring that other spawned futures /// The caller is responsible for ensuring that other spawned futures
/// complete execution by calling `block_on` or `run`. /// complete execution by calling `block_on` or `run`.
pub fn block_on<F>(&mut self, f: F) -> F::Output pub fn block_on<F>(&self, f: F) -> F::Output
where where
F: Future, F: Future,
{ {
self.local.block_on(&mut self.rt, f) self.local.block_on(&self.rt, f)
} }
} }

View file

@ -16,8 +16,8 @@ name = "ntex_service"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
futures-util = "0.3.9" futures-util = "0.3.13"
pin-project-lite = "0.2.4" pin-project-lite = "0.2.4"
[dev-dependencies] [dev-dependencies]
ntex-rt = "0.1" ntex-rt = "0.2"

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.3.0-b.1] - 2021-02-22
* Migrate to tokio 1.x
## [0.2.1] - 2021-02-22 ## [0.2.1] - 2021-02-22
* http: Fix http date header update task * http: Fix http date header update task

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex" name = "ntex"
version = "0.2.1" version = "0.3.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services" description = "Framework for composable network services"
readme = "README.md" readme = "README.md"
@ -36,28 +36,28 @@ compress = ["flate2", "brotli2"]
cookie = ["coo-kie", "coo-kie/percent-encode"] cookie = ["coo-kie", "coo-kie/percent-encode"]
[dependencies] [dependencies]
ntex-codec = "0.3.0" ntex-codec = "0.4.0"
ntex-rt = "0.1.2" ntex-rt = "0.2.0"
ntex-rt-macros = "0.1" ntex-rt-macros = "0.1"
ntex-router = "0.3.8" ntex-router = "0.4.0"
ntex-service = "0.1.5" ntex-service = "0.1.5"
ntex-macros = "0.1" ntex-macros = "0.1"
base64 = "0.13" base64 = "0.13"
bitflags = "1.2.1" bitflags = "1.2.1"
bytes = "0.5.6" bytes = "1.0"
bytestring = "0.1.5" bytestring = "1.0"
derive_more = "0.99.11" derive_more = "0.99.11"
either = "1.6.1" either = "1.6.1"
encoding_rs = "0.8.26" encoding_rs = "0.8.26"
futures = "0.3.12" futures = "0.3.13"
ahash = "0.7.0" ahash = "0.7.0"
h2 = "0.2.4" h2 = "0.3"
http = "0.2.1" http = "0.2.1"
httparse = "1.3" httparse = "1.3"
log = "0.4" log = "0.4"
mime = "0.3" mime = "0.3"
mio = "0.6.22" mio = "0.7.9"
num_cpus = "1.13" num_cpus = "1.13"
percent-encoding = "2.1" percent-encoding = "2.1"
pin-project-lite = "0.2.4" pin-project-lite = "0.2.4"
@ -70,35 +70,31 @@ serde_json = "1.0"
serde_urlencoded = "0.7.0" serde_urlencoded = "0.7.0"
socket2 = "0.3.12" socket2 = "0.3.12"
url = "2.1" url = "2.1"
time = { version = "0.2.15", default-features = false, features = ["std"] }
coo-kie = { version = "0.14.2", package = "cookie", optional = true } coo-kie = { version = "0.14.2", package = "cookie", optional = true }
tokio = "0.2.6" time = { version = "0.2.15", default-features = false, features = ["std"] }
tokio = { version = "1", default-features=false }
# resolver # resolver
trust-dns-proto = { version = "0.19.6", default-features = false } trust-dns-proto = { version = "0.20.0", default-features = false }
trust-dns-resolver = { version = "0.19.6", default-features = false, features=["system-config"] } trust-dns-resolver = { version = "0.20.0", default-features = false, features=["system-config", "tokio-runtime"] }
async-trait = "0.1.27" # this is only for trust-dns
# FIXME: Remove it and use mio own uds feature once mio 0.7 is released
mio-uds = { version = "0.6.7" }
# openssl # openssl
open-ssl = { version="0.10", package = "openssl", optional = true } open-ssl = { version="0.10", package = "openssl", optional = true }
tokio-openssl = { version = "0.4.0", optional = true } tokio-openssl = { version = "0.6.1", optional = true }
# rustls # rustls
rust-tls = { version = "0.19.0", package = "rustls", optional = true } rust-tls = { version = "0.19.0", package = "rustls", optional = true }
webpki = { version = "0.21.2", optional = true } webpki = { version = "0.21.4", optional = true }
webpki-roots = { version = "0.21.0", optional = true } webpki-roots = { version = "0.21.0", optional = true }
tokio-rustls = { version = "0.15.0", optional = true } tokio-rustls = { version = "0.22.0", optional = true }
# compression # compression
brotli2 = { version="0.3.2", optional = true } brotli2 = { version="0.3.2", optional = true }
flate2 = { version = "1.0.14", optional = true } flate2 = { version = "1.0.20", optional = true }
[dev-dependencies] [dev-dependencies]
env_logger = "0.8" env_logger = "0.8"
serde_derive = "1.0" serde_derive = "1.0"
open-ssl = { version="0.10", package = "openssl" } open-ssl = { version="0.10", package = "openssl" }
rust-tls = { version = "0.19.0", package="rustls", features = ["dangerous_configuration"] } rust-tls = { version = "0.19.0", package="rustls", features = ["dangerous_configuration"] }
webpki = "0.21.2" webpki = "0.21.4"

View file

@ -6,7 +6,7 @@ use std::net::SocketAddr;
use either::Either; use either::Either;
/// Connect request /// Connect request
pub trait Address: Unpin { pub trait Address: Unpin + 'static {
/// Host name of the request /// Host name of the request
fn host(&self) -> &str; fn host(&self) -> &str;

View file

@ -14,23 +14,23 @@ pub mod openssl;
pub mod rustls; pub mod rustls;
pub use trust_dns_resolver::config::{self, ResolverConfig, ResolverOpts}; pub use trust_dns_resolver::config::{self, ResolverConfig, ResolverOpts};
pub use trust_dns_resolver::error::ResolveError;
use trust_dns_resolver::system_conf::read_system_conf; use trust_dns_resolver::system_conf::read_system_conf;
pub use trust_dns_resolver::{error::ResolveError, TokioAsyncResolver as DnsResolver};
use crate::rt::{net::TcpStream, Arbiter}; use crate::rt::{net::TcpStream, Arbiter};
pub use self::error::ConnectError; pub use self::error::ConnectError;
pub use self::message::{Address, Connect}; pub use self::message::{Address, Connect};
pub use self::resolve::{AsyncResolver, Resolver}; pub use self::resolve::Resolver;
pub use self::service::Connector; pub use self::service::Connector;
pub fn start_resolver(cfg: ResolverConfig, opts: ResolverOpts) -> AsyncResolver { pub fn start_resolver(cfg: ResolverConfig, opts: ResolverOpts) -> DnsResolver {
AsyncResolver::new(cfg, opts) DnsResolver::tokio(cfg, opts).unwrap()
} }
struct DefaultResolver(AsyncResolver); struct DefaultResolver(DnsResolver);
pub fn default_resolver() -> AsyncResolver { pub fn default_resolver() -> DnsResolver {
if Arbiter::contains_item::<DefaultResolver>() { if Arbiter::contains_item::<DefaultResolver>() {
Arbiter::get_item(|item: &DefaultResolver| item.0.clone()) Arbiter::get_item(|item: &DefaultResolver| item.0.clone())
} else { } else {
@ -42,7 +42,7 @@ pub fn default_resolver() -> AsyncResolver {
} }
}; };
let resolver = AsyncResolver::new(cfg, opts); let resolver = DnsResolver::tokio(cfg, opts).unwrap();
Arbiter::set_item(DefaultResolver(resolver.clone())); Arbiter::set_item(DefaultResolver(resolver.clone()));
resolver resolver
@ -50,13 +50,12 @@ pub fn default_resolver() -> AsyncResolver {
} }
/// Resolve and connect to remote host /// Resolve and connect to remote host
pub fn connect<T: Address, U>( pub fn connect<T, U>(message: U) -> impl Future<Output = Result<TcpStream, ConnectError>>
message: U,
) -> impl Future<Output = Result<TcpStream, ConnectError>>
where where
T: Address + 'static,
Connect<T>: From<U>, Connect<T>: From<U>,
{ {
service::ConnectServiceResponse::new( service::ConnectServiceResponse::new(Box::pin(
Resolver::new(default_resolver()).lookup(message.into()), Resolver::new(default_resolver()).lookup(message.into()),
) ))
} }

View file

@ -1,15 +1,13 @@
use std::future::Future; use std::{io, pin::Pin, task::Context, task::Poll};
use std::io;
use std::task::{Context, Poll};
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Future, FutureExt, LocalBoxFuture, Ready};
pub use open_ssl::ssl::{Error as SslError, SslConnector, SslMethod}; pub use open_ssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
pub use tokio_openssl::{HandshakeError, SslStream}; pub use tokio_openssl::SslStream;
use crate::rt::net::TcpStream; use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory}; use crate::service::{Service, ServiceFactory};
use super::{Address, AsyncResolver, Connect, ConnectError, Connector}; use super::{Address, Connect, ConnectError, Connector, DnsResolver};
pub struct OpensslConnector<T> { pub struct OpensslConnector<T> {
connector: Connector<T>, connector: Connector<T>,
@ -26,7 +24,7 @@ impl<T> OpensslConnector<T> {
} }
/// Construct new connect service with custom dns resolver /// Construct new connect service with custom dns resolver
pub fn with_resolver(connector: SslConnector, resolver: AsyncResolver) -> Self { pub fn with_resolver(connector: SslConnector, resolver: DnsResolver) -> Self {
OpensslConnector { OpensslConnector {
connector: Connector::new(resolver), connector: Connector::new(resolver),
openssl: connector, openssl: connector,
@ -54,8 +52,14 @@ impl<T: Address + 'static> OpensslConnector<T> {
match openssl.configure() { match openssl.configure() {
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()), Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
Ok(config) => match tokio_openssl::connect(config, &host, io).await { Ok(config) => {
Ok(io) => { let config = config
.into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let mut io = SslStream::new(config, io)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
match Pin::new(&mut io).connect().await {
Ok(_) => {
trace!("SSL Handshake success: {:?}", host); trace!("SSL Handshake success: {:?}", host);
Ok(io) Ok(io)
} }
@ -64,7 +68,8 @@ impl<T: Address + 'static> OpensslConnector<T> {
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)) Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))
.into()) .into())
} }
}, }
}
} }
} }
} }

View file

@ -1,33 +1,16 @@
use std::cell::RefCell; use std::{
use std::future::Future; fmt, marker::PhantomData, net::SocketAddr, pin::Pin, rc::Rc, task::Context,
use std::marker::PhantomData; task::Poll,
use std::net::SocketAddr;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use std::{fmt, io};
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
use futures::ready;
use trust_dns_proto::{error::ProtoError, Time};
use trust_dns_resolver::config::{ResolverConfig, ResolverOpts};
use trust_dns_resolver::error::ResolveError;
use trust_dns_resolver::lookup_ip::LookupIp;
use trust_dns_resolver::name_server::{
GenericConnection, GenericConnectionProvider, RuntimeProvider, Spawn,
}; };
use trust_dns_resolver::AsyncResolver as TAsyncResolver;
use crate::channel::condition::{Condition, Waiter}; use futures::future::{ok, Either, Future, Ready};
use crate::rt::net::{self, TcpStream};
use super::{default_resolver, Address, Connect, ConnectError, DnsResolver};
use crate::service::{Service, ServiceFactory}; use crate::service::{Service, ServiceFactory};
use super::{default_resolver, Address, Connect, ConnectError};
/// DNS Resolver Service /// DNS Resolver Service
pub struct Resolver<T> { pub struct Resolver<T> {
resolver: AsyncResolver, resolver: Rc<DnsResolver>,
_t: PhantomData<T>, _t: PhantomData<T>,
} }
@ -41,9 +24,9 @@ impl<T> fmt::Debug for Resolver<T> {
impl<T> Resolver<T> { impl<T> Resolver<T> {
/// Create new resolver instance with custom configuration and options. /// Create new resolver instance with custom configuration and options.
pub fn new(resolver: AsyncResolver) -> Self { pub fn new(resolver: DnsResolver) -> Self {
Resolver { Resolver {
resolver, resolver: Rc::new(resolver),
_t: PhantomData, _t: PhantomData,
} }
} }
@ -54,7 +37,7 @@ impl<T: Address> Resolver<T> {
pub fn lookup( pub fn lookup(
&self, &self,
mut req: Connect<T>, mut req: Connect<T>,
) -> Either<ResolverFuture<T>, Ready<Result<Connect<T>, ConnectError>>> { ) -> impl Future<Output = Result<Connect<T>, ConnectError>> {
if req.addr.is_some() || req.req.addr().is_some() { if req.addr.is_some() || req.req.addr().is_some() {
Either::Right(ok(req)) Either::Right(ok(req))
} else if let Ok(ip) = req.host().parse() { } else if let Ok(ip) = req.host().parse() {
@ -62,7 +45,43 @@ impl<T: Address> Resolver<T> {
Either::Right(ok(req)) Either::Right(ok(req))
} else { } else {
trace!("DNS resolver: resolving host {:?}", req.host()); trace!("DNS resolver: resolving host {:?}", req.host());
Either::Left(ResolverFuture::new(req, &self.resolver)) let resolver = self.resolver.clone();
Either::Left(async move {
let fut = if let Some(host) = req.host().splitn(2, ':').next() {
resolver.lookup_ip(host)
} else {
resolver.lookup_ip(req.host())
};
match fut.await {
Ok(ips) => {
let port = req.port();
let req = req
.set_addrs(ips.iter().map(|ip| SocketAddr::new(ip, port)));
trace!(
"DNS resolver: host {:?} resolved to {:?}",
req.host(),
req.addrs()
);
if req.addr.is_none() {
Err(ConnectError::NoRecords)
} else {
Ok(req)
}
}
Err(e) => {
trace!(
"DNS resolver: failed to resolve host {:?} err: {}",
req.host(),
e
);
Err(e.into())
}
}
})
} }
} }
} }
@ -70,7 +89,7 @@ impl<T: Address> Resolver<T> {
impl<T> Default for Resolver<T> { impl<T> Default for Resolver<T> {
fn default() -> Resolver<T> { fn default() -> Resolver<T> {
Resolver { Resolver {
resolver: default_resolver(), resolver: Rc::new(default_resolver()),
_t: PhantomData, _t: PhantomData,
} }
} }
@ -103,7 +122,7 @@ impl<T: Address> Service for Resolver<T> {
type Request = Connect<T>; type Request = Connect<T>;
type Response = Connect<T>; type Response = Connect<T>;
type Error = ConnectError; type Error = ConnectError;
type Future = Either<ResolverFuture<T>, Ready<Result<Connect<T>, Self::Error>>>; type Future = Pin<Box<dyn Future<Output = Result<Connect<T>, Self::Error>>>>;
#[inline] #[inline]
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -112,291 +131,7 @@ impl<T: Address> Service for Resolver<T> {
#[inline] #[inline]
fn call(&self, req: Connect<T>) -> Self::Future { fn call(&self, req: Connect<T>) -> Self::Future {
self.lookup(req) Box::pin(self.lookup(req))
}
}
#[doc(hidden)]
/// Resolver future
pub struct ResolverFuture<T: Address> {
req: Option<Connect<T>>,
lookup: LookupIpFuture,
}
impl<T: Address> ResolverFuture<T> {
pub fn new(req: Connect<T>, resolver: &AsyncResolver) -> Self {
let lookup = if let Some(host) = req.host().splitn(2, ':').next() {
resolver.lookup_ip(host)
} else {
resolver.lookup_ip(req.host())
};
ResolverFuture {
lookup,
req: Some(req),
}
}
}
impl<T: Address> Future for ResolverFuture<T> {
type Output = Result<Connect<T>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
match Pin::new(&mut this.lookup).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(ips)) => {
let req = this.req.take().unwrap();
let port = req.port();
let req = req.set_addrs(ips.iter().map(|ip| SocketAddr::new(ip, port)));
trace!(
"DNS resolver: host {:?} resolved to {:?}",
req.host(),
req.addrs()
);
if req.addr.is_none() {
Poll::Ready(Err(ConnectError::NoRecords))
} else {
Poll::Ready(Ok(req))
}
}
Poll::Ready(Err(e)) => {
trace!(
"DNS resolver: failed to resolve host {:?} err: {}",
this.req.as_ref().unwrap().host(),
e
);
Poll::Ready(Err(e.into()))
}
}
}
}
#[derive(Clone, Debug)]
/// An asynchronous resolver for DNS.
pub struct AsyncResolver {
state: Rc<RefCell<AsyncResolverState>>,
}
impl AsyncResolver {
/// Construct a new `AsyncResolver` with the provided configuration.
///
/// # Arguments
///
/// * `config` - configuration, name_servers, etc. for the Resolver
/// * `options` - basic lookup options for the resolver
pub fn new(config: ResolverConfig, options: ResolverOpts) -> Self {
AsyncResolver {
state: Rc::new(RefCell::new(AsyncResolverState::New(Some(
TAsyncResolver::new(config, options, Handle).boxed_local(),
)))),
}
}
/// Constructs a new Resolver with the system configuration.
///
/// This will use `/etc/resolv.conf` on Unix OSes and the registry on Windows.
pub fn from_system_conf() -> Self {
AsyncResolver {
state: Rc::new(RefCell::new(AsyncResolverState::New(Some(
TokioAsyncResolver::from_system_conf(Handle).boxed_local(),
)))),
}
}
pub fn lookup_ip(&self, host: &str) -> LookupIpFuture {
LookupIpFuture {
host: host.to_string(),
state: self.state.clone(),
fut: LookupIpState::Init,
}
}
}
type TokioAsyncResolver =
TAsyncResolver<GenericConnection, GenericConnectionProvider<TokioRuntime>>;
enum AsyncResolverState {
New(Option<LocalBoxFuture<'static, Result<TokioAsyncResolver, ResolveError>>>),
Creating(Condition),
Resolver(Box<TokioAsyncResolver>),
}
impl fmt::Debug for AsyncResolverState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AsyncResolverState::New(_) => write!(f, "AsyncResolverState::New"),
AsyncResolverState::Creating(_) => write!(f, "AsyncResolverState::Creating"),
AsyncResolverState::Resolver(_) => write!(f, "AsyncResolverState::Resolver"),
}
}
}
pub struct LookupIpFuture {
host: String,
state: Rc<RefCell<AsyncResolverState>>,
fut: LookupIpState,
}
enum LookupIpState {
Init,
Create(LocalBoxFuture<'static, Result<TokioAsyncResolver, ResolveError>>),
Wait(Waiter),
Lookup(LocalBoxFuture<'static, Result<LookupIp, ResolveError>>),
}
impl Future for LookupIpFuture {
type Output = Result<LookupIp, ResolveError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().get_mut();
loop {
match this.fut {
LookupIpState::Lookup(ref mut fut) => return Pin::new(fut).poll(cx),
LookupIpState::Create(ref mut fut) => {
let resolver = ready!(Pin::new(fut).poll(cx))?;
this.fut = LookupIpState::Init;
*this.state.borrow_mut() =
AsyncResolverState::Resolver(Box::new(resolver));
}
LookupIpState::Wait(ref mut waiter) => {
ready!(waiter.poll_waiter(cx));
this.fut = LookupIpState::Init;
}
LookupIpState::Init => {
let mut state = this.state.borrow_mut();
match &mut *state {
AsyncResolverState::New(ref mut fut) => {
this.fut = LookupIpState::Create(fut.take().unwrap());
*state = AsyncResolverState::Creating(Condition::default());
}
AsyncResolverState::Creating(ref cond) => {
this.fut = LookupIpState::Wait(cond.wait());
}
AsyncResolverState::Resolver(ref resolver) => {
let host = this.host.clone();
let resolver: TokioAsyncResolver = Clone::clone(resolver);
this.fut = LookupIpState::Lookup(
async move { resolver.lookup_ip(host.as_str()).await }
.boxed_local(),
);
}
}
}
}
}
}
}
#[derive(Clone, Copy)]
struct Handle;
impl Spawn for Handle {
fn spawn_bg<F>(&mut self, future: F)
where
F: Future<Output = Result<(), ProtoError>> + Send + 'static,
{
crate::rt::spawn(future.map(|_| ()));
}
}
struct UdpSocket(net::UdpSocket);
#[derive(Clone)]
struct TokioRuntime;
impl RuntimeProvider for TokioRuntime {
type Handle = Handle;
type Tcp = AsyncIo02As03<TcpStream>;
type Timer = TokioTime;
type Udp = UdpSocket;
}
/// Conversion from `tokio::io::{AsyncRead, AsyncWrite}` to `std::io::{AsyncRead, AsyncWrite}`
struct AsyncIo02As03<T>(T);
use crate::codec::{AsyncRead as AsyncRead02, AsyncWrite as AsyncWrite02};
use futures::io::{AsyncRead, AsyncWrite};
impl<T> Unpin for AsyncIo02As03<T> {}
impl<R: AsyncRead02 + Unpin> AsyncRead for AsyncIo02As03<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
impl<W: AsyncWrite02 + Unpin> AsyncWrite for AsyncIo02As03<W> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
#[async_trait::async_trait]
impl trust_dns_proto::tcp::Connect for AsyncIo02As03<TcpStream> {
type Transport = AsyncIo02As03<TcpStream>;
async fn connect(addr: SocketAddr) -> io::Result<Self::Transport> {
TcpStream::connect(&addr).await.map(AsyncIo02As03)
}
}
#[async_trait::async_trait]
impl trust_dns_proto::udp::UdpSocket for UdpSocket {
async fn bind(addr: &SocketAddr) -> io::Result<Self> {
net::UdpSocket::bind(addr).await.map(UdpSocket)
}
async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.0.recv_from(buf).await
}
async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
self.0.send_to(buf, target).await
}
}
/// New type which is implemented using tokio::time::{Delay, Timeout}
struct TokioTime;
#[async_trait::async_trait]
impl Time for TokioTime {
async fn delay_for(duration: std::time::Duration) {
tokio::time::delay_for(duration).await
}
async fn timeout<F: 'static + Future + Send>(
duration: std::time::Duration,
future: F,
) -> Result<F::Output, std::io::Error> {
tokio::time::timeout(duration, future)
.await
.map_err(move |_| {
std::io::Error::new(std::io::ErrorKind::TimedOut, "future timed out")
})
} }
} }
@ -408,7 +143,7 @@ mod tests {
#[ntex_rt::test] #[ntex_rt::test]
async fn resolver() { async fn resolver() {
let resolver = Resolver::new(AsyncResolver::from_system_conf()); let resolver = Resolver::new(DnsResolver::tokio_from_system_conf().unwrap());
assert!(format!("{:?}", resolver).contains("Resolver")); assert!(format!("{:?}", resolver).contains("Resolver"));
let srv = resolver.new_service(()).await.unwrap(); let srv = resolver.new_service(()).await.unwrap();
assert!(lazy(|cx| srv.poll_ready(cx)).await.is_ready()); assert!(lazy(|cx| srv.poll_ready(cx)).await.is_ready());

View file

@ -12,7 +12,7 @@ use webpki::DNSNameRef;
use crate::rt::net::TcpStream; use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory}; use crate::service::{Service, ServiceFactory};
use super::{Address, AsyncResolver, Connect, ConnectError, Connector}; use super::{Address, Connect, ConnectError, Connector, DnsResolver};
/// Rustls connector factory /// Rustls connector factory
pub struct RustlsConnector<T> { pub struct RustlsConnector<T> {
@ -29,7 +29,7 @@ impl<T> RustlsConnector<T> {
} }
/// Construct new connect service with custom dns resolver /// Construct new connect service with custom dns resolver
pub fn with_resolver(config: Arc<ClientConfig>, resolver: AsyncResolver) -> Self { pub fn with_resolver(config: Arc<ClientConfig>, resolver: DnsResolver) -> Self {
RustlsConnector { RustlsConnector {
config, config,
connector: Connector::new(resolver), connector: Connector::new(resolver),

View file

@ -7,7 +7,7 @@ use futures::future::{ok, Future, FutureExt, LocalBoxFuture, Ready};
use crate::rt::net::TcpStream; use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory}; use crate::service::{Service, ServiceFactory};
use super::{Address, AsyncResolver, Connect, ConnectError, Resolver}; use super::{Address, Connect, ConnectError, DnsResolver, Resolver};
pub struct Connector<T> { pub struct Connector<T> {
resolver: Resolver<T>, resolver: Resolver<T>,
@ -15,7 +15,7 @@ pub struct Connector<T> {
impl<T> Connector<T> { impl<T> Connector<T> {
/// Construct new connect service with custom dns resolver /// Construct new connect service with custom dns resolver
pub fn new(resolver: AsyncResolver) -> Self { pub fn new(resolver: DnsResolver) -> Self {
Connector { Connector {
resolver: Resolver::new(resolver), resolver: Resolver::new(resolver),
} }
@ -31,7 +31,7 @@ impl<T: Address> Connector<T> {
where where
Connect<T>: From<U>, Connect<T>: From<U>,
{ {
ConnectServiceResponse::new(self.resolver.lookup(message.into())) ConnectServiceResponse::new(self.resolver.call(message.into()))
} }
} }
@ -79,7 +79,7 @@ impl<T: Address> Service for Connector<T> {
#[inline] #[inline]
fn call(&self, req: Connect<T>) -> Self::Future { fn call(&self, req: Connect<T>) -> Self::Future {
ConnectServiceResponse::new(self.resolver.lookup(req)) ConnectServiceResponse::new(self.resolver.call(req))
} }
} }

View file

@ -61,7 +61,7 @@ where
} }
} }
#[derive(Copy, Clone)] #[derive(Copy, Clone, Debug)]
pub(super) enum ReadResult { pub(super) enum ReadResult {
Pending, Pending,
Updated, Updated,
@ -85,7 +85,7 @@ where
// read all data from socket // read all data from socket
let mut result = ReadResult::Pending; let mut result = ReadResult::Pending;
loop { loop {
match Pin::new(&mut *io).poll_read_buf(cx, buf) { match crate::codec::poll_read_buf(Pin::new(&mut *io), cx, buf) {
Poll::Pending => break, Poll::Pending => break,
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {

View file

@ -466,8 +466,11 @@ impl State {
Ok(None) => { Ok(None) => {
let st = self.0.clone(); let st = self.0.clone();
let n = poll_fn(|cx| { let n = poll_fn(|cx| {
Pin::new(&mut *io) crate::codec::poll_read_buf(
.poll_read_buf(cx, &mut *st.read_buf.borrow_mut()) Pin::new(&mut *io),
cx,
&mut *st.read_buf.borrow_mut(),
)
}) })
.await .await
.map_err(Either::Right)?; .map_err(Either::Right)?;
@ -502,7 +505,11 @@ impl State {
return match codec.decode(&mut buf) { return match codec.decode(&mut buf) {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))), Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => { Ok(None) => {
let n = ready!(Pin::new(&mut *io).poll_read_buf(cx, &mut *buf)) let n = ready!(crate::codec::poll_read_buf(
Pin::new(&mut *io),
cx,
&mut *buf
))
.map_err(Either::Right)?; .map_err(Either::Right)?;
if n == 0 { if n == 0 {
Poll::Ready(Ok(None)) Poll::Ready(Ok(None))

View file

@ -3,16 +3,16 @@ use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc, time::Duration};
use bytes::{Buf, BytesMut}; use bytes::{Buf, BytesMut};
use crate::codec::{AsyncRead, AsyncWrite}; use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
use crate::framed::State; use crate::framed::State;
use crate::rt::time::{delay_for, Delay}; use crate::rt::time::{sleep, Sleep};
const HW: usize = 16 * 1024; const HW: usize = 16 * 1024;
#[derive(Debug)] #[derive(Debug)]
enum IoWriteState { enum IoWriteState {
Processing, Processing,
Shutdown(Option<Delay>, Shutdown), Shutdown(Option<Pin<Box<Sleep>>>, Shutdown),
} }
#[derive(Debug)] #[derive(Debug)]
@ -50,7 +50,7 @@ where
let disconnect_timeout = state.get_disconnect_timeout() as u64; let disconnect_timeout = state.get_disconnect_timeout() as u64;
let st = IoWriteState::Shutdown( let st = IoWriteState::Shutdown(
if disconnect_timeout != 0 { if disconnect_timeout != 0 {
Some(delay_for(Duration::from_millis(disconnect_timeout))) Some(Box::pin(sleep(Duration::from_millis(disconnect_timeout))))
} else { } else {
None None
}, },
@ -87,7 +87,9 @@ where
let disconnect_timeout = this.state.get_disconnect_timeout() as u64; let disconnect_timeout = this.state.get_disconnect_timeout() as u64;
this.st = IoWriteState::Shutdown( this.st = IoWriteState::Shutdown(
if disconnect_timeout != 0 { if disconnect_timeout != 0 {
Some(delay_for(Duration::from_millis(disconnect_timeout))) Some(Box::pin(sleep(Duration::from_millis(
disconnect_timeout,
))))
} else { } else {
None None
}, },
@ -161,10 +163,13 @@ where
Shutdown::Shutdown => { Shutdown::Shutdown => {
// read until 0 or err // read until 0 or err
let mut buf = [0u8; 512]; let mut buf = [0u8; 512];
let mut read_buf = ReadBuf::new(&mut buf);
let mut io = this.io.borrow_mut(); let mut io = this.io.borrow_mut();
loop { loop {
match Pin::new(&mut *io).poll_read(cx, &mut buf) { match Pin::new(&mut *io).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(0)) | Poll::Ready(Err(_)) => { Poll::Ready(Err(_)) | Poll::Ready(Ok(_))
if read_buf.filled().is_empty() =>
{
this.state.set_wr_shutdown_complete(); this.state.set_wr_shutdown_complete();
log::trace!("write task is stopped"); log::trace!("write task is stopped");
return Poll::Ready(()); return Poll::Ready(());

View file

@ -1,9 +1,6 @@
use std::future::Future; use std::{fmt, future::Future, io, net, pin::Pin, task::Context, task::Poll};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io, mem, net};
use crate::codec::{AsyncRead, AsyncWrite, Framed}; use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use crate::http::body::Body; use crate::http::body::Body;
use crate::http::h1::ClientCodec; use crate::http::h1::ClientCodec;
use crate::http::{RequestHeadType, ResponseHead}; use crate::http::{RequestHeadType, ResponseHead};
@ -133,18 +130,11 @@ impl fmt::Debug for BoxedSocket {
} }
impl AsyncRead for BoxedSocket { impl AsyncRead for BoxedSocket {
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [mem::MaybeUninit<u8>],
) -> bool {
self.0.as_read().prepare_uninitialized_buffer(buf)
}
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut [u8], buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf) Pin::new(self.get_mut().0.as_read_mut()).poll_read(cx, buf)
} }
} }

View file

@ -48,7 +48,7 @@ pub struct Connector {
connector: BoxedConnector, connector: BoxedConnector,
ssl_connector: Option<BoxedConnector>, ssl_connector: Option<BoxedConnector>,
#[allow(dead_code)] #[allow(dead_code)]
resolver: connect::AsyncResolver, resolver: connect::DnsResolver,
} }
trait Io: AsyncRead + AsyncWrite + Unpin {} trait Io: AsyncRead + AsyncWrite + Unpin {}
@ -61,7 +61,7 @@ impl Default for Connector {
} }
impl Connector { impl Connector {
pub fn new(resolver: connect::AsyncResolver) -> Connector { pub fn new(resolver: connect::DnsResolver) -> Connector {
let conn = Connector { let conn = Connector {
connector: boxed::service( connector: boxed::service(
TcpConnector::new(resolver.clone()) TcpConnector::new(resolver.clone())

View file

@ -1,14 +1,9 @@
use std::io::Write; use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem, time};
use bytes::buf::BufMutExt; use bytes::{BufMut, Bytes, BytesMut};
use bytes::{Bytes, BytesMut}; use futures::{future::poll_fn, SinkExt, Stream, StreamExt};
use futures::future::poll_fn;
use futures::{SinkExt, Stream, StreamExt};
use crate::codec::{AsyncRead, AsyncWrite, Framed}; use crate::codec::{AsyncRead, AsyncWrite, Framed, ReadBuf};
use crate::http::body::{BodySize, MessageBody}; use crate::http::body::{BodySize, MessageBody};
use crate::http::error::PayloadError; use crate::http::error::PayloadError;
use crate::http::h1; use crate::http::h1;
@ -199,18 +194,11 @@ where
} }
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> { impl<T: AsyncRead + AsyncWrite + Unpin + 'static> AsyncRead for H1Connection<T> {
unsafe fn prepare_uninitialized_buffer(
&self,
buf: &mut [mem::MaybeUninit<u8>],
) -> bool {
self.io.as_ref().unwrap().prepare_uninitialized_buffer(buf)
}
fn poll_read( fn poll_read(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut [u8], buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<()>> {
Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf) Pin::new(&mut self.io.as_mut().unwrap()).poll_read(cx, buf)
} }
} }

View file

@ -8,7 +8,7 @@ use h2::client::{handshake, Connection, SendRequest};
use http::uri::Authority; use http::uri::Authority;
use crate::channel::pool; use crate::channel::pool;
use crate::codec::{AsyncRead, AsyncWrite}; use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
use crate::http::Protocol; use crate::http::Protocol;
use crate::rt::{spawn, time::delay_for, time::Delay}; use crate::rt::{spawn, time::delay_for, time::Delay};
use crate::service::Service; use crate::service::Service;
@ -255,10 +255,11 @@ where
} else { } else {
let mut io = conn.io; let mut io = conn.io;
let mut buf = [0; 2]; let mut buf = [0; 2];
let mut read_buf = ReadBuf::new(&mut buf);
if let ConnectionType::H1(ref mut s) = io { if let ConnectionType::H1(ref mut s) = io {
match Pin::new(s).poll_read(cx, &mut buf) { match Pin::new(s).poll_read(cx, &mut read_buf) {
Poll::Pending => (), Poll::Pending => (),
Poll::Ready(Ok(n)) if n > 0 => { Poll::Ready(Ok(_)) if !read_buf.filled().is_empty() => {
if let ConnectionType::H1(io) = io { if let ConnectionType::H1(io) = io {
CloseConnection::spawn(io, self.disconnect_timeout); CloseConnection::spawn(io, self.disconnect_timeout);
} }
@ -368,7 +369,7 @@ where
struct CloseConnection<T> { struct CloseConnection<T> {
io: T, io: T,
timeout: Option<Delay>, timeout: Option<Pin<Box<Delay>>>,
shutdown: bool, shutdown: bool,
} }
@ -378,7 +379,7 @@ where
{ {
fn spawn(io: T, timeout: Duration) { fn spawn(io: T, timeout: Duration) {
let timeout = if timeout != ZERO { let timeout = if timeout != ZERO {
Some(delay_for(timeout)) Some(Box::pin(delay_for(timeout)))
} else { } else {
None None
}; };
@ -412,12 +413,13 @@ where
Poll::Ready(_) => (), Poll::Ready(_) => (),
Poll::Pending => { Poll::Pending => {
let mut buf = [0u8; 512]; let mut buf = [0u8; 512];
let mut read_buf = ReadBuf::new(&mut buf);
loop { loop {
match Pin::new(&mut this.io).poll_read(cx, &mut buf) { match Pin::new(&mut this.io).poll_read(cx, &mut read_buf) {
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(Err(_)) => return Poll::Ready(()), Poll::Ready(Err(_)) => return Poll::Ready(()),
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(_)) => {
if n == 0 { if read_buf.filled().is_empty() {
return Poll::Ready(()); return Poll::Ready(());
} }
continue; continue;

View file

@ -54,7 +54,7 @@ impl From<PrepForSendingError> for SendRequestError {
pub enum SendClientRequest { pub enum SendClientRequest {
Fut( Fut(
Pin<Box<dyn Future<Output = Result<ClientResponse, SendRequestError>>>>, Pin<Box<dyn Future<Output = Result<ClientResponse, SendRequestError>>>>,
Option<Delay>, Option<Pin<Box<Delay>>>,
bool, bool,
), ),
Err(Option<SendRequestError>), Err(Option<SendRequestError>),
@ -66,7 +66,7 @@ impl SendClientRequest {
response_decompress: bool, response_decompress: bool,
timeout: Option<Duration>, timeout: Option<Duration>,
) -> SendClientRequest { ) -> SendClientRequest {
let delay = timeout.map(delay_for); let delay = timeout.map(|d| Box::pin(delay_for(d)));
SendClientRequest::Fut(send, delay, response_decompress) SendClientRequest::Fut(send, delay, response_decompress)
} }
} }

View file

@ -116,7 +116,7 @@ pub(super) trait MessageType: Sized {
let mut pos = 0; let mut pos = 0;
let mut has_date = false; let mut has_date = false;
let mut remaining = dst.capacity() - dst.len(); let mut remaining = dst.capacity() - dst.len();
let mut buf = dst.bytes_mut().as_mut_ptr() as *mut u8; let mut buf = dst.chunk_mut().as_mut_ptr() as *mut u8;
for (key, value) in headers { for (key, value) in headers {
match *key { match *key {
CONNECTION => continue, CONNECTION => continue,
@ -140,7 +140,7 @@ pub(super) trait MessageType: Sized {
pos = 0; pos = 0;
dst.reserve(len * 2); dst.reserve(len * 2);
remaining = dst.capacity() - dst.len(); remaining = dst.capacity() - dst.len();
buf = dst.bytes_mut().as_mut_ptr() as *mut u8; buf = dst.chunk_mut().as_mut_ptr() as *mut u8;
} }
copy_nonoverlapping(k.as_ptr(), buf, k_len); copy_nonoverlapping(k.as_ptr(), buf, k_len);
buf = buf.add(k_len); buf = buf.add(k_len);
@ -167,7 +167,7 @@ pub(super) trait MessageType: Sized {
pos = 0; pos = 0;
dst.reserve(len * 2); dst.reserve(len * 2);
remaining = dst.capacity() - dst.len(); remaining = dst.capacity() - dst.len();
buf = dst.bytes_mut().as_mut_ptr() as *mut u8; buf = dst.chunk_mut().as_mut_ptr() as *mut u8;
} }
copy_nonoverlapping(k.as_ptr(), buf, k_len); copy_nonoverlapping(k.as_ptr(), buf, k_len);
buf = buf.add(k_len); buf = buf.add(k_len);

View file

@ -1,6 +1,4 @@
use std::sync::mpsc as sync_mpsc; use std::{io, sync::mpsc as sync_mpsc, sync::Arc, thread, time::Duration};
use std::time::Duration;
use std::{io, thread};
use log::{error, info}; use log::{error, info};
use slab::Slab; use slab::Slab;
@ -8,94 +6,89 @@ use slab::Slab;
use crate::rt::time::{delay_until, Instant}; use crate::rt::time::{delay_until, Instant};
use crate::rt::System; use crate::rt::System;
use super::socket::{SocketAddr, SocketListener, StdListener}; use super::socket::{Listener, SocketAddr};
use super::worker::{Conn, WorkerClient}; use super::worker::{Conn, WorkerClient};
use super::{Server, Token}; use super::{Server, Token};
const DELTA: usize = 100;
const NOTIFY: mio::Token = mio::Token(0);
#[derive(Debug)]
pub(super) enum Command { pub(super) enum Command {
Pause, Pause,
Resume, Resume,
Stop, Stop,
Worker(WorkerClient), Worker(WorkerClient),
Timer,
WorkerAvailable,
} }
struct ServerSocketInfo { struct ServerSocketInfo {
addr: SocketAddr, addr: SocketAddr,
token: Token, token: Token,
sock: SocketListener, sock: Listener,
timeout: Option<Instant>, timeout: Option<Instant>,
} }
#[derive(Clone)] #[derive(Debug, Clone)]
pub(super) struct AcceptNotify(mio::SetReadiness); pub(super) struct AcceptNotify(Arc<mio::Waker>, sync_mpsc::Sender<Command>);
impl AcceptNotify { impl AcceptNotify {
pub(super) fn new(ready: mio::SetReadiness) -> Self { pub(super) fn new(waker: Arc<mio::Waker>, tx: sync_mpsc::Sender<Command>) -> Self {
AcceptNotify(ready) AcceptNotify(waker, tx)
} }
pub(super) fn notify(&self) { pub(super) fn send(&self, cmd: Command) {
let _ = self.0.set_readiness(mio::Ready::readable()); let _ = self.1.send(cmd);
} let _ = self.0.wake();
}
impl Default for AcceptNotify {
fn default() -> Self {
AcceptNotify::new(mio::Registration::new2().1)
} }
} }
pub(super) struct AcceptLoop { pub(super) struct AcceptLoop {
cmd_reg: Option<mio::Registration>, notify: AcceptNotify,
cmd_ready: mio::SetReadiness, inner: Option<(sync_mpsc::Receiver<Command>, mio::Poll, Server)>,
notify_reg: Option<mio::Registration>,
notify_ready: mio::SetReadiness,
tx: sync_mpsc::Sender<Command>,
rx: Option<sync_mpsc::Receiver<Command>>,
srv: Option<Server>,
} }
impl AcceptLoop { impl AcceptLoop {
pub(super) fn new(srv: Server) -> AcceptLoop { pub(super) fn new(srv: Server) -> AcceptLoop {
// Create a poll instance
let poll = mio::Poll::new()
.map_err(|e| panic!("Can not create mio::Poll {}", e))
.unwrap();
let (tx, rx) = sync_mpsc::channel(); let (tx, rx) = sync_mpsc::channel();
let (cmd_reg, cmd_ready) = mio::Registration::new2(); let waker = Arc::new(
let (notify_reg, notify_ready) = mio::Registration::new2(); mio::Waker::new(poll.registry(), NOTIFY)
.map_err(|e| panic!("Can not create mio::Waker {}", e))
.unwrap(),
);
let notify = AcceptNotify::new(waker, tx);
AcceptLoop { AcceptLoop {
tx, notify,
cmd_ready, inner: Some((rx, poll, srv)),
cmd_reg: Some(cmd_reg),
notify_ready,
notify_reg: Some(notify_reg),
rx: Some(rx),
srv: Some(srv),
} }
} }
pub(super) fn send(&self, msg: Command) { pub(super) fn send(&self, msg: Command) {
let _ = self.tx.send(msg); self.notify.send(msg)
let _ = self.cmd_ready.set_readiness(mio::Ready::readable());
} }
pub(super) fn get_notify(&self) -> AcceptNotify { pub(super) fn notify(&self) -> AcceptNotify {
AcceptNotify::new(self.notify_ready.clone()) self.notify.clone()
} }
pub(super) fn start( pub(super) fn start(
&mut self, &mut self,
socks: Vec<(Token, StdListener)>, socks: Vec<(Token, Listener)>,
workers: Vec<WorkerClient>, workers: Vec<WorkerClient>,
) { ) {
let srv = self.srv.take().expect("Can not re-use AcceptInfo"); let (rx, poll, srv) = self
.inner
.take()
.expect("AcceptLoop cannot be used multiple times");
Accept::start( Accept::start(rx, poll, socks, srv, workers, self.notify.clone());
self.rx.take().expect("Can not re-use AcceptInfo"),
self.cmd_reg.take().expect("Can not re-use AcceptInfo"),
self.notify_reg.take().expect("Can not re-use AcceptInfo"),
socks,
srv,
workers,
);
} }
} }
@ -105,16 +98,11 @@ struct Accept {
sockets: Slab<ServerSocketInfo>, sockets: Slab<ServerSocketInfo>,
workers: Vec<WorkerClient>, workers: Vec<WorkerClient>,
srv: Server, srv: Server,
timer: (mio::Registration, mio::SetReadiness), notify: AcceptNotify,
next: usize, next: usize,
backpressure: bool, backpressure: bool,
} }
const DELTA: usize = 100;
const CMD: mio::Token = mio::Token(0);
const TIMER: mio::Token = mio::Token(1);
const NOTIFY: mio::Token = mio::Token(2);
/// This function defines errors that are per-connection. Which basically /// This function defines errors that are per-connection. Which basically
/// means that if we get this error from `accept()` system call it means /// means that if we get this error from `accept()` system call it means
/// next connection might be ready to be accepted. /// next connection might be ready to be accepted.
@ -129,14 +117,13 @@ fn connection_error(e: &io::Error) -> bool {
} }
impl Accept { impl Accept {
#![allow(clippy::too_many_arguments)]
fn start( fn start(
rx: sync_mpsc::Receiver<Command>, rx: sync_mpsc::Receiver<Command>,
cmd_reg: mio::Registration, poll: mio::Poll,
notify_reg: mio::Registration, socks: Vec<(Token, Listener)>,
socks: Vec<(Token, StdListener)>,
srv: Server, srv: Server,
workers: Vec<WorkerClient>, workers: Vec<WorkerClient>,
notify: AcceptNotify,
) { ) {
let sys = System::current(); let sys = System::current();
@ -145,87 +132,50 @@ impl Accept {
.name("ntex-server accept loop".to_owned()) .name("ntex-server accept loop".to_owned())
.spawn(move || { .spawn(move || {
System::set_current(sys); System::set_current(sys);
let mut accept = Accept::new(rx, socks, workers, srv); Accept::new(rx, poll, socks, workers, srv, notify).poll()
// Start listening for incoming commands
if let Err(err) = accept.poll.register(
&cmd_reg,
CMD,
mio::Ready::readable(),
mio::PollOpt::edge(),
) {
panic!("Can not register Registration: {}", err);
}
// Start listening for notify updates
if let Err(err) = accept.poll.register(
&notify_reg,
NOTIFY,
mio::Ready::readable(),
mio::PollOpt::edge(),
) {
panic!("Can not register Registration: {}", err);
}
accept.poll();
}); });
} }
fn new( fn new(
rx: sync_mpsc::Receiver<Command>, rx: sync_mpsc::Receiver<Command>,
socks: Vec<(Token, StdListener)>, poll: mio::Poll,
socks: Vec<(Token, Listener)>,
workers: Vec<WorkerClient>, workers: Vec<WorkerClient>,
srv: Server, srv: Server,
notify: AcceptNotify,
) -> Accept { ) -> Accept {
// Create a poll instance
let poll = match mio::Poll::new() {
Ok(poll) => poll,
Err(err) => panic!("Can not create mio::Poll: {}", err),
};
// Start accept // Start accept
let mut sockets = Slab::new(); let mut sockets = Slab::new();
for (hnd_token, lst) in socks.into_iter() { for (hnd_token, mut lst) in socks.into_iter() {
let addr = lst.local_addr(); let addr = lst.local_addr();
let server = lst.into_listener();
let entry = sockets.vacant_entry(); let entry = sockets.vacant_entry();
let token = entry.key(); let token = entry.key();
// Start listening for incoming connections // Start listening for incoming connections
if let Err(err) = poll.register( if let Err(err) = poll.registry().register(
&server, &mut lst,
mio::Token(token + DELTA), mio::Token(token + DELTA),
mio::Ready::readable(), mio::Interest::READABLE,
mio::PollOpt::edge(),
) { ) {
panic!("Can not register io: {}", err); panic!("Can not register io: {}", err);
} }
entry.insert(ServerSocketInfo { entry.insert(ServerSocketInfo {
addr, addr,
sock: lst,
token: hnd_token, token: hnd_token,
sock: server,
timeout: None, timeout: None,
}); });
} }
// Timer
let (tm, tmr) = mio::Registration::new2();
if let Err(err) =
poll.register(&tm, TIMER, mio::Ready::readable(), mio::PollOpt::edge())
{
panic!("Can not register Registration: {}", err);
}
Accept { Accept {
poll, poll,
rx, rx,
sockets, sockets,
workers, workers,
notify,
srv, srv,
next: 0, next: 0,
timer: (tm, tmr),
backpressure: false, backpressure: false,
} }
} }
@ -251,13 +201,11 @@ impl Accept {
for event in events.iter() { for event in events.iter() {
let token = event.token(); let token = event.token();
match token { match token {
CMD => { NOTIFY => {
if !self.process_cmd() { if !self.process_cmd() {
return; return;
} }
} }
TIMER => self.process_timer(),
NOTIFY => self.backpressure(false),
_ => { _ => {
let token = usize::from(token); let token = usize::from(token);
if token < DELTA { if token < DELTA {
@ -275,11 +223,10 @@ impl Accept {
for (token, info) in self.sockets.iter_mut() { for (token, info) in self.sockets.iter_mut() {
if let Some(inst) = info.timeout.take() { if let Some(inst) = info.timeout.take() {
if now > inst { if now > inst {
if let Err(err) = self.poll.register( if let Err(err) = self.poll.registry().register(
&info.sock, &mut info.sock,
mio::Token(token + DELTA), mio::Token(token + DELTA),
mio::Ready::readable(), mio::Interest::READABLE,
mio::PollOpt::edge(),
) { ) {
error!("Can not register server socket {}", err); error!("Can not register server socket {}", err);
} else { } else {
@ -298,7 +245,9 @@ impl Accept {
Ok(cmd) => match cmd { Ok(cmd) => match cmd {
Command::Pause => { Command::Pause => {
for (_, info) in self.sockets.iter_mut() { for (_, info) in self.sockets.iter_mut() {
if let Err(err) = self.poll.deregister(&info.sock) { if let Err(err) =
self.poll.registry().deregister(&mut info.sock)
{
error!("Can not deregister server socket {}", err); error!("Can not deregister server socket {}", err);
} else { } else {
info!("Paused accepting connections on {}", info.addr); info!("Paused accepting connections on {}", info.addr);
@ -306,12 +255,11 @@ impl Accept {
} }
} }
Command::Resume => { Command::Resume => {
for (token, info) in self.sockets.iter() { for (token, info) in self.sockets.iter_mut() {
if let Err(err) = self.poll.register( if let Err(err) = self.poll.registry().register(
&info.sock, &mut info.sock,
mio::Token(token + DELTA), mio::Token(token + DELTA),
mio::Ready::readable(), mio::Interest::READABLE,
mio::PollOpt::edge(),
) { ) {
error!("Can not resume socket accept process: {}", err); error!("Can not resume socket accept process: {}", err);
} else { } else {
@ -323,9 +271,9 @@ impl Accept {
} }
} }
Command::Stop => { Command::Stop => {
for (_, info) in self.sockets.iter() { for (_, info) in self.sockets.iter_mut() {
trace!("Stopping socket listener: {}", info.addr); trace!("Stopping socket listener: {}", info.addr);
let _ = self.poll.deregister(&info.sock); let _ = self.poll.registry().deregister(&mut info.sock);
} }
return false; return false;
} }
@ -333,12 +281,18 @@ impl Accept {
self.backpressure(false); self.backpressure(false);
self.workers.push(worker); self.workers.push(worker);
} }
Command::Timer => {
self.process_timer();
}
Command::WorkerAvailable => {
self.backpressure(false);
}
}, },
Err(err) => match err { Err(err) => match err {
sync_mpsc::TryRecvError::Empty => break, sync_mpsc::TryRecvError::Empty => break,
sync_mpsc::TryRecvError::Disconnected => { sync_mpsc::TryRecvError::Disconnected => {
for (_, info) in self.sockets.iter() { for (_, info) in self.sockets.iter_mut() {
let _ = self.poll.deregister(&info.sock); let _ = self.poll.registry().deregister(&mut info.sock);
} }
return false; return false;
} }
@ -352,16 +306,15 @@ impl Accept {
if self.backpressure { if self.backpressure {
if !on { if !on {
self.backpressure = false; self.backpressure = false;
for (token, info) in self.sockets.iter() { for (token, info) in self.sockets.iter_mut() {
if info.timeout.is_some() { if info.timeout.is_some() {
// socket will re-register itself after timeout // socket will re-register itself after timeout
continue; continue;
} }
if let Err(err) = self.poll.register( if let Err(err) = self.poll.registry().register(
&info.sock, &mut info.sock,
mio::Token(token + DELTA), mio::Token(token + DELTA),
mio::Ready::readable(), mio::Interest::READABLE,
mio::PollOpt::edge(),
) { ) {
error!("Can not resume socket accept process: {}", err); error!("Can not resume socket accept process: {}", err);
} else { } else {
@ -371,10 +324,10 @@ impl Accept {
} }
} else if on { } else if on {
self.backpressure = true; self.backpressure = true;
for (_, info) in self.sockets.iter() { for (_, info) in self.sockets.iter_mut() {
if info.timeout.is_none() { if info.timeout.is_none() {
trace!("Enabling backpressure for {}", info.addr); trace!("Enabling backpressure for {}", info.addr);
let _ = self.poll.deregister(&info.sock); let _ = self.poll.registry().deregister(&mut info.sock);
} }
} }
} }
@ -452,18 +405,19 @@ impl Accept {
Err(ref e) if connection_error(e) => continue, Err(ref e) if connection_error(e) => continue,
Err(e) => { Err(e) => {
error!("Error accepting connection: {}", e); error!("Error accepting connection: {}", e);
if let Err(err) = self.poll.deregister(&info.sock) { if let Err(err) = self.poll.registry().deregister(&mut info.sock)
{
error!("Can not deregister server socket {}", err); error!("Can not deregister server socket {}", err);
} }
// sleep after error // sleep after error
info.timeout = Some(Instant::now() + Duration::from_millis(500)); info.timeout = Some(Instant::now() + Duration::from_millis(500));
let r = self.timer.1.clone(); let notify = self.notify.clone();
System::current().arbiter().send(Box::pin(async move { System::current().arbiter().spawn(Box::pin(async move {
delay_until(Instant::now() + Duration::from_millis(510)) delay_until(Instant::now() + Duration::from_millis(510))
.await; .await;
let _ = r.set_readiness(mio::Ready::readable()); notify.send(Command::Timer);
})); }));
return; return;
} }

View file

@ -19,7 +19,7 @@ use super::accept::{AcceptLoop, AcceptNotify, Command};
use super::config::{ConfiguredService, ServiceConfig}; use super::config::{ConfiguredService, ServiceConfig};
use super::service::{Factory, InternalServiceFactory, StreamServiceFactory}; use super::service::{Factory, InternalServiceFactory, StreamServiceFactory};
use super::signals::{Signal, Signals}; use super::signals::{Signal, Signals};
use super::socket::StdListener; use super::socket::Listener;
use super::worker::{self, Worker, WorkerAvailability, WorkerClient}; use super::worker::{self, Worker, WorkerAvailability, WorkerClient};
use super::{Server, ServerCommand, Token}; use super::{Server, ServerCommand, Token};
@ -30,7 +30,7 @@ pub struct ServerBuilder {
backlog: i32, backlog: i32,
workers: Vec<(usize, WorkerClient)>, workers: Vec<(usize, WorkerClient)>,
services: Vec<Box<dyn InternalServiceFactory>>, services: Vec<Box<dyn InternalServiceFactory>>,
sockets: Vec<(Token, String, StdListener)>, sockets: Vec<(Token, String, Listener)>,
accept: AcceptLoop, accept: AcceptLoop,
exit: bool, exit: bool,
shutdown_timeout: Duration, shutdown_timeout: Duration,
@ -150,7 +150,7 @@ impl ServerBuilder {
for (name, lst) in cfg.services { for (name, lst) in cfg.services {
let token = self.token.next(); let token = self.token.next();
srv.stream(token, name.clone(), lst.local_addr()?); srv.stream(token, name.clone(), lst.local_addr()?);
self.sockets.push((token, name, StdListener::Tcp(lst))); self.sockets.push((token, name, Listener::from_tcp(lst)));
} }
self.services.push(Box::new(srv)); self.services.push(Box::new(srv));
} }
@ -180,8 +180,11 @@ impl ServerBuilder {
factory.clone(), factory.clone(),
lst.local_addr()?, lst.local_addr()?,
)); ));
self.sockets self.sockets.push((
.push((token, name.as_ref().to_string(), StdListener::Tcp(lst))); token,
name.as_ref().to_string(),
Listener::from_tcp(lst),
));
} }
Ok(self) Ok(self)
} }
@ -232,7 +235,7 @@ impl ServerBuilder {
addr, addr,
)); ));
self.sockets self.sockets
.push((token, name.as_ref().to_string(), StdListener::Uds(lst))); .push((token, name.as_ref().to_string(), Listener::from_uds(lst)));
Ok(self) Ok(self)
} }
@ -254,7 +257,7 @@ impl ServerBuilder {
lst.local_addr()?, lst.local_addr()?,
)); ));
self.sockets self.sockets
.push((token, name.as_ref().to_string(), StdListener::Tcp(lst))); .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
Ok(self) Ok(self)
} }
@ -273,7 +276,7 @@ impl ServerBuilder {
// start workers // start workers
let mut workers = Vec::new(); let mut workers = Vec::new();
for idx in 0..self.threads { for idx in 0..self.threads {
let worker = self.start_worker(idx, self.accept.get_notify()); let worker = self.start_worker(idx, self.accept.notify());
workers.push(worker.clone()); workers.push(worker.clone());
self.workers.push((idx, worker)); self.workers.push((idx, worker));
} }
@ -438,7 +441,7 @@ impl ServerBuilder {
break; break;
} }
let worker = self.start_worker(new_idx, self.accept.get_notify()); let worker = self.start_worker(new_idx, self.accept.notify());
self.workers.push((new_idx, worker.clone())); self.workers.push((new_idx, worker.clone()));
self.accept.send(Command::Worker(worker)); self.accept.send(Command::Worker(worker));
} }
@ -509,10 +512,6 @@ pub(crate) fn create_tcp_listener(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::future::ok;
use std::sync::mpsc;
use std::{net, thread, time};
use super::*; use super::*;
use crate::server::{signals, Server, TestServer}; use crate::server::{signals, Server, TestServer};
use crate::service::fn_service; use crate::service::fn_service;
@ -520,6 +519,10 @@ mod tests {
#[cfg(unix)] #[cfg(unix)]
#[ntex_rt::test] #[ntex_rt::test]
async fn test_signals() { async fn test_signals() {
use futures::future::ok;
use std::sync::mpsc;
use std::{net, thread, time};
fn start(tx: mpsc::Sender<(Server, net::SocketAddr)>) -> thread::JoinHandle<()> { fn start(tx: mpsc::Sender<(Server, net::SocketAddr)>) -> thread::JoinHandle<()> {
thread::spawn(move || { thread::spawn(move || {
let mut sys = crate::rt::System::new("test"); let mut sys = crate::rt::System::new("test");
@ -546,11 +549,11 @@ mod tests {
let h = start(tx); let h = start(tx);
let (srv, addr) = rx.recv().unwrap(); let (srv, addr) = rx.recv().unwrap();
thread::sleep(time::Duration::from_millis(300)); crate::rt::time::sleep(time::Duration::from_millis(300)).await;
assert!(net::TcpStream::connect(addr).is_ok()); assert!(net::TcpStream::connect(addr).is_ok());
srv.signal(*sig); srv.signal(*sig);
thread::sleep(time::Duration::from_millis(300)); crate::rt::time::sleep(time::Duration::from_millis(300)).await;
assert!(net::TcpStream::connect(addr).is_err()); assert!(net::TcpStream::connect(addr).is_err());
let _ = h.join(); let _ = h.join();
} }

View file

@ -6,7 +6,7 @@ use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use std::{fmt, io}; use std::{fmt, io};
pub use open_ssl::ssl::{AlpnError, SslAcceptor, SslAcceptorBuilder}; pub use open_ssl::ssl::{AlpnError, Ssl, SslAcceptor, SslAcceptorBuilder};
pub use tokio_openssl::SslStream; pub use tokio_openssl::SslStream;
use futures::future::{ok, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, FutureExt, LocalBoxFuture, Ready};
@ -107,20 +107,22 @@ where
#[inline] #[inline]
fn call(&self, req: Self::Request) -> Self::Future { fn call(&self, req: Self::Request) -> Self::Future {
let acc = self.acceptor.clone(); let ssl = Ssl::new(self.acceptor.context())
.expect("Provided SSL acceptor was invalid.");
AcceptorServiceResponse { AcceptorServiceResponse {
_guard: self.conns.get(), _guard: self.conns.get(),
delay: if self.timeout == ZERO { delay: if self.timeout == ZERO {
None None
} else { } else {
Some(delay_for(self.timeout)) Some(Box::pin(delay_for(self.timeout)))
}, },
fut: async move { fut: async move {
let acc = acc; let mut io = SslStream::new(ssl, req)?;
tokio_openssl::accept(&acc, req).await.map_err(|e| { Pin::new(&mut io).accept().await.map_err(|e| {
let e: Box<dyn Error> = Box::new(e); let e: Box<dyn Error> = Box::new(e);
e e
}) })?;
Ok(io)
} }
.boxed_local(), .boxed_local(),
} }
@ -132,7 +134,7 @@ where
T: AsyncRead + AsyncWrite, T: AsyncRead + AsyncWrite,
{ {
fut: LocalBoxFuture<'static, Result<SslStream<T>, Box<dyn Error>>>, fut: LocalBoxFuture<'static, Result<SslStream<T>, Box<dyn Error>>>,
delay: Option<Delay>, delay: Option<Pin<Box<Delay>>>,
_guard: CounterGuard, _guard: CounterGuard,
} }

View file

@ -1,13 +1,7 @@
use std::error::Error;
use std::future::Future;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration; use std::{error::Error, io, marker::PhantomData, pin::Pin, sync::Arc, time::Duration};
use futures::future::{ok, Ready}; use futures::future::{ok, Future, Ready};
use tokio_rustls::{Accept, TlsAcceptor}; use tokio_rustls::{Accept, TlsAcceptor};
pub use rust_tls::{ServerConfig, Session}; pub use rust_tls::{ServerConfig, Session};
@ -15,7 +9,7 @@ pub use tokio_rustls::server::TlsStream;
pub use webpki_roots::TLS_SERVER_ROOTS; pub use webpki_roots::TLS_SERVER_ROOTS;
use crate::codec::{AsyncRead, AsyncWrite}; use crate::codec::{AsyncRead, AsyncWrite};
use crate::rt::time::{delay_for, Delay}; use crate::rt::time::{sleep, Sleep};
use crate::service::{Service, ServiceFactory}; use crate::service::{Service, ServiceFactory};
use crate::util::counter::{Counter, CounterGuard}; use crate::util::counter::{Counter, CounterGuard};
@ -112,7 +106,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Service for AcceptorService<T> {
delay: if self.timeout == ZERO { delay: if self.timeout == ZERO {
None None
} else { } else {
Some(delay_for(self.timeout)) Some(Box::pin(sleep(self.timeout)))
}, },
} }
} }
@ -123,7 +117,7 @@ where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
fut: Accept<T>, fut: Accept<T>,
delay: Option<Delay>, delay: Option<Pin<Box<Sleep>>>,
_guard: CounterGuard, _guard: CounterGuard,
} }

View file

@ -11,13 +11,13 @@ use crate::rt::spawn;
use crate::service::{Service, ServiceFactory}; use crate::service::{Service, ServiceFactory};
use crate::util::counter::CounterGuard; use crate::util::counter::CounterGuard;
use super::socket::{FromStream, StdStream}; use super::socket::{FromStream, Stream};
use super::Token; use super::Token;
/// Server message /// Server message
pub(super) enum ServerMessage { pub(super) enum ServerMessage {
/// New stream /// New stream
Connect(StdStream), Connect(Stream),
/// Gracefull shutdown /// Gracefull shutdown
Shutdown(Duration), Shutdown(Duration),
/// Force shutdown /// Force shutdown
@ -84,7 +84,7 @@ where
fn call(&self, (guard, req): (Option<CounterGuard>, ServerMessage)) -> Self::Future { fn call(&self, (guard, req): (Option<CounterGuard>, ServerMessage)) -> Self::Future {
match req { match req {
ServerMessage::Connect(stream) => { ServerMessage::Connect(stream) => {
let stream = FromStream::from_stdstream(stream).map_err(|e| { let stream = FromStream::from_stream(stream).map_err(|e| {
error!("Can not convert to an async tcp stream: {}", e); error!("Can not convert to an async tcp stream: {}", e);
}); });

View file

@ -1,8 +1,4 @@
use std::pin::Pin; use std::{future::Future, pin::Pin, task::Context, task::Poll};
use std::task::{Context, Poll};
use futures::future::{Future, FutureExt};
use futures::stream::{unfold, Stream, StreamExt};
use crate::server::Server; use crate::server::Server;
@ -22,38 +18,37 @@ pub(crate) enum Signal {
pub(super) struct Signals { pub(super) struct Signals {
srv: Server, srv: Server,
streams: Vec<(Signal, Pin<Box<dyn Stream<Item = ()>>>)>, #[cfg(not(unix))]
signal: Pin<Box<dyn Future<Output = std::io::Result<()>>>>,
#[cfg(unix)]
signals: Vec<(Signal, crate::rt::signal::unix::Signal)>,
} }
impl Signals { impl Signals {
pub(super) fn new(srv: Server) -> Signals { pub(super) fn new(srv: Server) -> Signals {
let mut signals = Signals { #[cfg(not(unix))]
{
Signals {
srv, srv,
streams: vec![( signal: Box::pin(crate::rt::signal::ctrl_c()),
Signal::Int, }
unfold((), |_| { }
crate::rt::signal::ctrl_c().map(|res| match res {
Ok(_) => Some(((), ())),
Err(_) => None,
})
})
.boxed_local(),
)],
};
#[cfg(unix)] #[cfg(unix)]
{ {
use crate::rt::signal::unix; use crate::rt::signal::unix;
let sig_map = [ let sig_map = [
(unix::SignalKind::interrupt(), Signal::Int),
(unix::SignalKind::hangup(), Signal::Hup), (unix::SignalKind::hangup(), Signal::Hup),
(unix::SignalKind::terminate(), Signal::Term), (unix::SignalKind::terminate(), Signal::Term),
(unix::SignalKind::quit(), Signal::Quit), (unix::SignalKind::quit(), Signal::Quit),
]; ];
let mut signals = Vec::new();
for (kind, sig) in sig_map.iter() { for (kind, sig) in sig_map.iter() {
match unix::signal(*kind) { match unix::signal(*kind) {
Ok(stream) => signals.streams.push((*sig, stream.boxed_local())), Ok(stream) => signals.push((*sig, stream)),
Err(e) => log::error!( Err(e) => log::error!(
"Can not initialize stream handler for {:?} err: {}", "Can not initialize stream handler for {:?} err: {}",
sig, sig,
@ -61,9 +56,9 @@ impl Signals {
), ),
} }
} }
}
signals Signals { srv, signals }
}
} }
} }
@ -71,42 +66,26 @@ impl Future for Signals {
type Output = (); type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
for idx in 0..self.streams.len() { #[cfg(not(unix))]
loop { match self.signal.as_mut().poll(cx) {
match Pin::new(&mut self.streams[idx].1).poll_next(cx) { Poll::Ready(_) => {
Poll::Ready(None) => return Poll::Ready(()), self.srv.signal(Signal::Int);
Poll::Pending => break, Poll::Ready(())
Poll::Ready(Some(_)) => { }
let sig = self.streams[idx].0; Poll::Pending => Poll::Pending,
}
#[cfg(unix)]
{
let mut sigs = Vec::new();
for (sig, fut) in self.signals.iter_mut() {
if Pin::new(fut).poll_recv(cx).is_ready() {
sigs.push(*sig)
}
}
for sig in sigs {
self.srv.signal(sig); self.srv.signal(sig);
} }
}
}
}
Poll::Pending Poll::Pending
} }
} }
#[cfg(test)]
mod tests {
use futures::channel::mpsc;
use futures::future::{lazy, ready};
use futures::stream::once;
use super::*;
use crate::server::ServerCommand;
#[ntex_rt::test]
async fn signals() {
let (tx, mut rx) = mpsc::unbounded();
let server = Server::new(tx);
let mut signals = Signals::new(server);
signals.streams = vec![(Signal::Int, once(ready(())).boxed_local())];
let _ = lazy(|cx| Pin::new(&mut signals).poll(cx)).await;
if let Some(ServerCommand::Signal(sig)) = rx.next().await {
assert_eq!(sig, Signal::Int);
}
}
} }

View file

@ -3,23 +3,23 @@ use std::{fmt, io, net};
use crate::codec::{AsyncRead, AsyncWrite}; use crate::codec::{AsyncRead, AsyncWrite};
use crate::rt::net::TcpStream; use crate::rt::net::TcpStream;
pub(crate) enum StdListener { pub(crate) enum Listener {
Tcp(net::TcpListener), Tcp(mio::net::TcpListener),
#[cfg(all(unix))] #[cfg(unix)]
Uds(std::os::unix::net::UnixListener), Uds(mio::net::UnixListener),
} }
pub(crate) enum SocketAddr { pub(crate) enum SocketAddr {
Tcp(net::SocketAddr), Tcp(net::SocketAddr),
#[cfg(all(unix))] #[cfg(unix)]
Uds(std::os::unix::net::SocketAddr), Uds(mio::net::SocketAddr),
} }
impl fmt::Display for SocketAddr { impl fmt::Display for SocketAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self { match *self {
SocketAddr::Tcp(ref addr) => write!(f, "{}", addr), SocketAddr::Tcp(ref addr) => write!(f, "{}", addr),
#[cfg(all(unix))] #[cfg(unix)]
SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr), SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr),
} }
} }
@ -29,118 +29,102 @@ impl fmt::Debug for SocketAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self { match *self {
SocketAddr::Tcp(ref addr) => write!(f, "{:?}", addr), SocketAddr::Tcp(ref addr) => write!(f, "{:?}", addr),
#[cfg(all(unix))] #[cfg(unix)]
SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr), SocketAddr::Uds(ref addr) => write!(f, "{:?}", addr),
} }
} }
} }
impl fmt::Debug for StdListener { impl fmt::Debug for Listener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self { match *self {
StdListener::Tcp(ref lst) => write!(f, "{:?}", lst), Listener::Tcp(ref lst) => write!(f, "{:?}", lst),
#[cfg(all(unix))] #[cfg(unix)]
StdListener::Uds(ref lst) => write!(f, "{:?}", lst), Listener::Uds(ref lst) => write!(f, "{:?}", lst),
} }
} }
} }
impl fmt::Display for StdListener { impl fmt::Display for Listener {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self { match *self {
StdListener::Tcp(ref lst) => write!(f, "{}", lst.local_addr().ok().unwrap()), Listener::Tcp(ref lst) => write!(f, "{}", lst.local_addr().ok().unwrap()),
#[cfg(all(unix))] #[cfg(unix)]
StdListener::Uds(ref lst) => { Listener::Uds(ref lst) => {
write!(f, "{:?}", lst.local_addr().ok().unwrap()) write!(f, "{:?}", lst.local_addr().ok().unwrap())
} }
} }
} }
} }
impl StdListener { impl Listener {
pub(super) fn from_tcp(lst: net::TcpListener) -> Self {
let _ = lst.set_nonblocking(true);
Listener::Tcp(mio::net::TcpListener::from_std(lst))
}
#[cfg(unix)]
pub(super) fn from_uds(lst: std::os::unix::net::UnixListener) -> Self {
let _ = lst.set_nonblocking(true);
Listener::Uds(mio::net::UnixListener::from_std(lst))
}
pub(crate) fn local_addr(&self) -> SocketAddr { pub(crate) fn local_addr(&self) -> SocketAddr {
match self { match self {
StdListener::Tcp(lst) => SocketAddr::Tcp(lst.local_addr().unwrap()), Listener::Tcp(lst) => SocketAddr::Tcp(lst.local_addr().unwrap()),
#[cfg(all(unix))] #[cfg(unix)]
StdListener::Uds(lst) => SocketAddr::Uds(lst.local_addr().unwrap()), Listener::Uds(lst) => SocketAddr::Uds(lst.local_addr().unwrap()),
} }
} }
pub(crate) fn into_listener(self) -> SocketListener { pub(crate) fn accept(&self) -> io::Result<Option<(Stream, SocketAddr)>> {
match self {
StdListener::Tcp(lst) => SocketListener::Tcp(
mio::net::TcpListener::from_std(lst)
.expect("Can not create mio::net::TcpListener"),
),
#[cfg(all(unix))]
StdListener::Uds(lst) => SocketListener::Uds(
mio_uds::UnixListener::from_listener(lst)
.expect("Can not create mio_uds::UnixListener"),
),
}
}
}
#[derive(Debug)]
pub enum StdStream {
Tcp(std::net::TcpStream),
#[cfg(all(unix))]
Uds(std::os::unix::net::UnixStream),
}
pub(crate) enum SocketListener {
Tcp(mio::net::TcpListener),
#[cfg(all(unix))]
Uds(mio_uds::UnixListener),
}
impl SocketListener {
pub(crate) fn accept(&self) -> io::Result<Option<(StdStream, SocketAddr)>> {
match *self { match *self {
SocketListener::Tcp(ref lst) => lst.accept_std().map(|(stream, addr)| { Listener::Tcp(ref lst) => lst.accept().map(|(stream, addr)| {
Some((StdStream::Tcp(stream), SocketAddr::Tcp(addr))) Some((Stream::Tcp(stream), SocketAddr::Tcp(addr)))
}), }),
#[cfg(all(unix))] #[cfg(unix)]
SocketListener::Uds(ref lst) => lst.accept_std().map(|res| { Listener::Uds(ref lst) => lst.accept().map(|(stream, addr)| {
res.map(|(stream, addr)| (StdStream::Uds(stream), SocketAddr::Uds(addr))) Some((Stream::Uds(stream), SocketAddr::Uds(addr)))
}), }),
} }
} }
} }
impl mio::Evented for SocketListener { impl mio::event::Source for Listener {
#[inline]
fn register( fn register(
&self, &mut self,
poll: &mio::Poll, poll: &mio::Registry,
token: mio::Token, token: mio::Token,
interest: mio::Ready, interest: mio::Interest,
opts: mio::PollOpt,
) -> io::Result<()> { ) -> io::Result<()> {
match *self { match *self {
SocketListener::Tcp(ref lst) => lst.register(poll, token, interest, opts), Listener::Tcp(ref mut lst) => lst.register(poll, token, interest),
#[cfg(all(unix))] #[cfg(unix)]
SocketListener::Uds(ref lst) => lst.register(poll, token, interest, opts), Listener::Uds(ref mut lst) => lst.register(poll, token, interest),
} }
} }
#[inline]
fn reregister( fn reregister(
&self, &mut self,
poll: &mio::Poll, poll: &mio::Registry,
token: mio::Token, token: mio::Token,
interest: mio::Ready, interest: mio::Interest,
opts: mio::PollOpt,
) -> io::Result<()> { ) -> io::Result<()> {
match *self { match *self {
SocketListener::Tcp(ref lst) => lst.reregister(poll, token, interest, opts), Listener::Tcp(ref mut lst) => lst.reregister(poll, token, interest),
#[cfg(all(unix))] #[cfg(unix)]
SocketListener::Uds(ref lst) => lst.reregister(poll, token, interest, opts), Listener::Uds(ref mut lst) => lst.reregister(poll, token, interest),
} }
} }
fn deregister(&self, poll: &mio::Poll) -> io::Result<()> {
#[inline]
fn deregister(&mut self, poll: &mio::Registry) -> io::Result<()> {
match *self { match *self {
SocketListener::Tcp(ref lst) => lst.deregister(poll), Listener::Tcp(ref mut lst) => lst.deregister(poll),
#[cfg(all(unix))] #[cfg(unix)]
SocketListener::Uds(ref lst) => { Listener::Uds(ref mut lst) => {
let res = lst.deregister(poll); let res = lst.deregister(poll);
// cleanup file path // cleanup file path
@ -155,28 +139,69 @@ impl mio::Evented for SocketListener {
} }
} }
pub trait FromStream: AsyncRead + AsyncWrite + Sized { #[derive(Debug)]
fn from_stdstream(sock: StdStream) -> io::Result<Self>; pub enum Stream {
Tcp(mio::net::TcpStream),
#[cfg(unix)]
Uds(mio::net::UnixStream),
} }
pub trait FromStream: AsyncRead + AsyncWrite + Sized {
fn from_stream(stream: Stream) -> io::Result<Self>;
}
#[cfg(unix)]
impl FromStream for TcpStream { impl FromStream for TcpStream {
fn from_stdstream(sock: StdStream) -> io::Result<Self> { fn from_stream(sock: Stream) -> io::Result<Self> {
match sock { match sock {
StdStream::Tcp(stream) => TcpStream::from_std(stream), Stream::Tcp(stream) => {
#[cfg(all(unix))] use std::os::unix::io::{FromRawFd, IntoRawFd};
StdStream::Uds(_) => { let fd = IntoRawFd::into_raw_fd(stream);
let sock: std::net::TcpStream = unsafe { FromRawFd::from_raw_fd(fd) };
let _ = sock.set_nonblocking(true);
TcpStream::from_std(sock)
}
#[cfg(unix)]
Stream::Uds(_) => {
panic!("Should not happen, bug in server impl"); panic!("Should not happen, bug in server impl");
} }
} }
} }
} }
#[cfg(all(unix))] #[cfg(windows)]
impl FromStream for crate::rt::net::UnixStream { impl FromStream for TcpStream {
fn from_stdstream(sock: StdStream) -> io::Result<Self> { fn from_stream(sock: Stream) -> io::Result<Self> {
match sock { match sock {
StdStream::Tcp(_) => panic!("Should not happen, bug in server impl"), Stream::Tcp(stream) => {
StdStream::Uds(stream) => crate::rt::net::UnixStream::from_std(stream), use std::os::windows::io::{FromRawSocket, IntoRawSocket};
let fd = IntoRawSocket::into_raw_socket(stream);
let sock: std::net::TcpStream =
unsafe { FromRawSocket::from_raw_socket(fd) };
let _ = sock.set_nonblocking(true);
TcpStream::from_std(sock)
}
#[cfg(unix)]
Stream::Uds(_) => {
panic!("Should not happen, bug in server impl");
}
}
}
}
#[cfg(unix)]
impl FromStream for crate::rt::net::UnixStream {
fn from_stream(sock: Stream) -> io::Result<Self> {
match sock {
Stream::Tcp(_) => panic!("Should not happen, bug in server impl"),
Stream::Uds(stream) => {
use std::os::unix::io::{FromRawFd, IntoRawFd};
let fd = IntoRawFd::into_raw_fd(stream);
let sock: std::os::unix::net::UnixStream =
unsafe { FromRawFd::from_raw_fd(fd) };
let _ = sock.set_nonblocking(true);
crate::rt::net::UnixStream::from_std(sock)
}
} }
} }
} }
@ -198,7 +223,7 @@ mod tests {
socket.set_reuse_address(true).unwrap(); socket.set_reuse_address(true).unwrap();
socket.bind(&SockAddr::from(addr)).unwrap(); socket.bind(&SockAddr::from(addr)).unwrap();
let tcp = socket.into_tcp_listener(); let tcp = socket.into_tcp_listener();
let lst = StdListener::Tcp(tcp); let lst = Listener::Tcp(mio::net::TcpListener::from_std(tcp));
assert!(format!("{:?}", lst).contains("TcpListener")); assert!(format!("{:?}", lst).contains("TcpListener"));
assert!(format!("{}", lst).contains("127.0.0.1")); assert!(format!("{}", lst).contains("127.0.0.1"));
} }
@ -209,13 +234,14 @@ mod tests {
use std::os::unix::net::UnixListener; use std::os::unix::net::UnixListener;
let _ = std::fs::remove_file("/tmp/sock.xxxxx"); let _ = std::fs::remove_file("/tmp/sock.xxxxx");
if let Ok(socket) = UnixListener::bind("/tmp/sock.xxxxx") { if let Ok(lst) = UnixListener::bind("/tmp/sock.xxxxx") {
let addr = socket.local_addr().expect("Couldn't get local address"); let lst = mio::net::UnixListener::from_std(lst);
let addr = lst.local_addr().expect("Couldn't get local address");
let a = SocketAddr::Uds(addr); let a = SocketAddr::Uds(addr);
assert!(format!("{:?}", a).contains("/tmp/sock.xxxxx")); assert!(format!("{:?}", a).contains("/tmp/sock.xxxxx"));
assert!(format!("{}", a).contains("/tmp/sock.xxxxx")); assert!(format!("{}", a).contains("/tmp/sock.xxxxx"));
let lst = StdListener::Uds(socket); let lst = Listener::Uds(lst);
assert!(format!("{:?}", lst).contains("/tmp/sock.xxxxx")); assert!(format!("{:?}", lst).contains("/tmp/sock.xxxxx"));
assert!(format!("{}", lst).contains("/tmp/sock.xxxxx")); assert!(format!("{}", lst).contains("/tmp/sock.xxxxx"));
} }

View file

@ -1,25 +1,25 @@
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time; use std::{pin::Pin, sync::Arc, time};
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}; use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::channel::oneshot; use futures::channel::oneshot;
use futures::future::{join_all, LocalBoxFuture, MapOk}; use futures::future::{join_all, LocalBoxFuture, MapOk};
use futures::{Future, FutureExt, Stream, TryFutureExt}; use futures::{Future, FutureExt, Stream as StdStream, TryFutureExt};
use crate::rt::time::{delay_until, Delay, Instant}; use crate::rt::time::{delay_until, Delay, Instant};
use crate::rt::{spawn, Arbiter}; use crate::rt::{spawn, Arbiter};
use crate::util::counter::Counter; use crate::util::counter::Counter;
use super::accept::AcceptNotify; use super::accept::{AcceptNotify, Command};
use super::service::{BoxedServerService, InternalServiceFactory, ServerMessage}; use super::service::{BoxedServerService, InternalServiceFactory, ServerMessage};
use super::socket::{SocketAddr, StdStream}; use super::socket::{SocketAddr, Stream};
use super::Token; use super::Token;
#[derive(Debug)]
pub(super) struct WorkerCommand(Conn); pub(super) struct WorkerCommand(Conn);
#[derive(Debug)]
/// Stop worker message. Returns `true` on successful shutdown /// Stop worker message. Returns `true` on successful shutdown
/// and `false` if some connections are still alive. /// and `false` if some connections are still alive.
pub(super) struct StopCommand { pub(super) struct StopCommand {
@ -29,7 +29,7 @@ pub(super) struct StopCommand {
#[derive(Debug)] #[derive(Debug)]
pub(super) struct Conn { pub(super) struct Conn {
pub(super) io: StdStream, pub(super) io: Stream,
pub(super) token: Token, pub(super) token: Token,
pub(super) peer: Option<SocketAddr>, pub(super) peer: Option<SocketAddr>,
} }
@ -55,7 +55,7 @@ thread_local! {
Counter::new(MAX_CONNS.load(Ordering::Relaxed)); Counter::new(MAX_CONNS.load(Ordering::Relaxed));
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub(super) struct WorkerClient { pub(super) struct WorkerClient {
pub(super) idx: usize, pub(super) idx: usize,
tx1: UnboundedSender<WorkerCommand>, tx1: UnboundedSender<WorkerCommand>,
@ -95,7 +95,7 @@ impl WorkerClient {
} }
} }
#[derive(Clone)] #[derive(Debug, Clone)]
pub(super) struct WorkerAvailability { pub(super) struct WorkerAvailability {
notify: AcceptNotify, notify: AcceptNotify,
available: Arc<AtomicBool>, available: Arc<AtomicBool>,
@ -116,7 +116,7 @@ impl WorkerAvailability {
pub(super) fn set(&self, val: bool) { pub(super) fn set(&self, val: bool) {
let old = self.available.swap(val, Ordering::Release); let old = self.available.swap(val, Ordering::Release);
if !old && val { if !old && val {
self.notify.notify() self.notify.send(Command::WorkerAvailable)
} }
} }
} }
@ -578,7 +578,11 @@ mod tests {
async fn basics() { async fn basics() {
let (_tx1, rx1) = unbounded(); let (_tx1, rx1) = unbounded();
let (mut tx2, rx2) = unbounded(); let (mut tx2, rx2) = unbounded();
let avail = WorkerAvailability::new(AcceptNotify::default()); let (sync_tx, _sync_rx) = std::sync::mpsc::channel();
let poll = mio::Poll::new().unwrap();
let waker = Arc::new(mio::Waker::new(poll.registry(), mio::Token(1)).unwrap());
let avail =
WorkerAvailability::new(AcceptNotify::new(waker.clone(), sync_tx.clone()));
let st = Arc::new(Mutex::new(St::Pending)); let st = Arc::new(Mutex::new(St::Pending));
let counter = Arc::new(Mutex::new(0)); let counter = Arc::new(Mutex::new(0));
@ -655,7 +659,7 @@ mod tests {
// force shutdown // force shutdown
let (_tx1, rx1) = unbounded(); let (_tx1, rx1) = unbounded();
let (mut tx2, rx2) = unbounded(); let (mut tx2, rx2) = unbounded();
let avail = WorkerAvailability::new(AcceptNotify::default()); let avail = WorkerAvailability::new(AcceptNotify::new(waker, sync_tx.clone()));
let f = SrvFactory { let f = SrvFactory {
st: st.clone(), st: st.clone(),
counter: counter.clone(), counter: counter.clone(),

View file

@ -8,7 +8,7 @@ use bytes::BytesMut;
use futures::future::poll_fn; use futures::future::poll_fn;
use futures::task::AtomicWaker; use futures::task::AtomicWaker;
use crate::codec::{AsyncRead, AsyncWrite}; use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
use crate::rt::time::delay_for; use crate::rt::time::delay_for;
/// Async io stream /// Async io stream
@ -244,24 +244,24 @@ impl AsyncRead for Io {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut [u8], buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<()>> {
let guard = self.local.lock().unwrap(); let guard = self.local.lock().unwrap();
let mut ch = guard.borrow_mut(); let mut ch = guard.borrow_mut();
ch.waker.register(cx.waker()); ch.waker.register(cx.waker());
if !ch.buf.is_empty() { if !ch.buf.is_empty() {
let size = std::cmp::min(ch.buf.len(), buf.len()); let size = std::cmp::min(ch.buf.len(), buf.capacity());
let b = ch.buf.split_to(size); let b = ch.buf.split_to(size);
buf[..size].copy_from_slice(&b); buf.put_slice(&b);
return Poll::Ready(Ok(size)); return Poll::Ready(Ok(()));
} }
match mem::take(&mut ch.read) { match mem::take(&mut ch.read) {
IoState::Ok => Poll::Pending, IoState::Ok => Poll::Pending,
IoState::Close => { IoState::Close => {
ch.read = IoState::Close; ch.read = IoState::Close;
Poll::Ready(Ok(0)) Poll::Ready(Ok(()))
} }
IoState::Pending => Poll::Pending, IoState::Pending => Poll::Pending,
IoState::Err(e) => Poll::Ready(Err(e)), IoState::Err(e) => Poll::Ready(Err(e)),

View file

@ -85,7 +85,7 @@ pub struct KeepAliveService<R, E, F> {
} }
struct Inner { struct Inner {
delay: Delay, delay: Pin<Box<Delay>>,
expire: Instant, expire: Instant,
} }
@ -101,7 +101,7 @@ where
time, time,
inner: RefCell::new(Inner { inner: RefCell::new(Inner {
expire, expire,
delay: delay_until(expire), delay: Box::pin(delay_until(expire)),
}), }),
_t: PhantomData, _t: PhantomData,
} }
@ -127,7 +127,7 @@ where
Poll::Ready(Err((self.f)())) Poll::Ready(Err((self.f)()))
} else { } else {
let expire = inner.expire; let expire = inner.expire;
inner.delay.reset(expire); inner.delay.as_mut().reset(expire);
let _ = Pin::new(&mut inner.delay).poll(cx); let _ = Pin::new(&mut inner.delay).poll(cx);
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }

View file

@ -154,7 +154,7 @@ where
} else { } else {
Either::Left(TimeoutServiceResponse { Either::Left(TimeoutServiceResponse {
fut: self.service.call(request), fut: self.service.call(request),
sleep: delay_for(self.timeout), sleep: Box::pin(delay_for(self.timeout)),
}) })
} }
} }
@ -167,7 +167,7 @@ pin_project_lite::pin_project! {
pub struct TimeoutServiceResponse<T: Service> { pub struct TimeoutServiceResponse<T: Service> {
#[pin] #[pin]
fut: T::Future, fut: T::Future,
sleep: Delay, sleep: Pin<Box<Delay>>,
} }
} }

View file

@ -113,7 +113,7 @@ impl WebResponseError<DefaultError> for crate::connect::openssl::SslError {}
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
/// `InternalServerError` for `openssl::ssl::HandshakeError` /// `InternalServerError` for `openssl::ssl::HandshakeError`
impl<T: fmt::Debug + 'static> WebResponseError<DefaultError> impl<T: fmt::Debug + 'static> WebResponseError<DefaultError>
for tokio_openssl::HandshakeError<T> for open_ssl::ssl::HandshakeError<T>
{ {
} }

View file

@ -18,6 +18,7 @@ use ntex::http::header::{
TRANSFER_ENCODING, TRANSFER_ENCODING,
}; };
use ntex::http::{Method, StatusCode}; use ntex::http::{Method, StatusCode};
use ntex::rt::time::{sleep, Sleep};
use ntex::web::middleware::Compress; use ntex::web::middleware::Compress;
use ntex::web::{ use ntex::web::{
@ -49,7 +50,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
struct TestBody { struct TestBody {
data: Bytes, data: Bytes,
chunk_size: usize, chunk_size: usize,
delay: ntex::rt::time::Delay, delay: Pin<Box<Sleep>>,
} }
impl TestBody { impl TestBody {
@ -57,7 +58,7 @@ impl TestBody {
TestBody { TestBody {
data, data,
chunk_size, chunk_size,
delay: ntex::rt::time::delay_for(std::time::Duration::from_millis(10)), delay: Box::pin(sleep(std::time::Duration::from_millis(10))),
} }
} }
} }
@ -71,7 +72,7 @@ impl futures::Stream for TestBody {
) -> Poll<Option<Self::Item>> { ) -> Poll<Option<Self::Item>> {
ready!(Pin::new(&mut self.delay).poll(cx)); ready!(Pin::new(&mut self.delay).poll(cx));
self.delay = ntex::rt::time::delay_for(std::time::Duration::from_millis(10)); self.delay = Box::pin(sleep(std::time::Duration::from_millis(10)));
let chunk_size = std::cmp::min(self.chunk_size, self.data.len()); let chunk_size = std::cmp::min(self.chunk_size, self.data.len());
let chunk = self.data.split_to(chunk_size); let chunk = self.data.split_to(chunk_size);
if chunk.is_empty() { if chunk.is_empty() {