diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index aaa1b007..78f2d45f 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -3,6 +3,17 @@ name: Checks on: [push, pull_request] jobs: + check: + name: Check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: stable + - run: + cargo check --tests --all --no-default-features --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" + clippy: name: Clippy runs-on: ubuntu-latest @@ -13,7 +24,7 @@ jobs: toolchain: stable components: clippy - run: - cargo test --all --no-default-features --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" + cargo clippy --tests --all --no-default-features --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" fmt: name: Rustfmt diff --git a/.github/workflows/cov.yml b/.github/workflows/cov.yml index be265444..c9f7a345 100644 --- a/.github/workflows/cov.yml +++ b/.github/workflows/cov.yml @@ -8,11 +8,6 @@ jobs: env: CARGO_TERM_COLOR: always steps: - - name: Free Disk Space - uses: jlumbroso/free-disk-space@main - with: - tool-cache: true - - uses: actions/checkout@v4 - name: Install Rust run: rustup update nightly @@ -26,18 +21,20 @@ jobs: - name: Clean coverage results run: cargo llvm-cov clean --workspace - - name: Code coverage (glommio) - continue-on-error: true - run: cargo llvm-cov --no-report --all --no-default-features --features="ntex/glommio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" -- --skip test_unhandled_data + - name: Code coverage (neon) + run: cargo llvm-cov --no-report --all --no-default-features --features="ntex/neon,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" + + - name: Code coverage (neon-uring) + run: cargo llvm-cov --no-report --all --no-default-features --features="ntex/neon-uring,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" - name: Code coverage (tokio) - run: cargo llvm-cov --no-report --all --no-default-features --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" + run: cargo llvm-cov --no-report --all --no-default-features --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" - name: Code coverage (compio) - run: cargo llvm-cov --no-report --all --no-default-features --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" + run: cargo llvm-cov --no-report --all --no-default-features --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" - name: Generate coverage report - run: cargo llvm-cov report --lcov --output-path lcov.info --ignore-filename-regex="ntex-compio|ntex-tokio|ntex-glommio|ntex-async-std" + run: cargo llvm-cov report --lcov --output-path lcov.info --ignore-filename-regex="ntex-compio|ntex-tokio" - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 863352e0..5297364c 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -16,11 +16,6 @@ jobs: runs-on: ubuntu-latest steps: - - name: Free Disk Space - uses: jlumbroso/free-disk-space@main - with: - tool-cache: true - - uses: actions/checkout@v4 - name: Install ${{ matrix.version }} @@ -44,21 +39,25 @@ jobs: path: ~/.cargo/git key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-index-trimmed-${{ hashFiles('**/Cargo.lock') }} + - name: Run tests (neon) + timeout-minutes: 40 + run: | + cargo test --all --no-default-features --features="ntex/neon,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" + + - name: Run tests (neon-uring) + timeout-minutes: 40 + run: | + cargo test --all --no-default-features --features="ntex/neon-uring,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" + - name: Run tests (tokio) timeout-minutes: 40 run: | - cargo test --all --no-fail-fast --no-default-features --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" + cargo test --all --no-fail-fast --no-default-features --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" - name: Run tests (compio) timeout-minutes: 40 run: | - cargo test --all --no-default-features --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" - - - name: Run tests (async-std) - timeout-minutes: 40 - continue-on-error: true - run: | - cargo test --all --no-default-features --features="ntex/async-std,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" + cargo test --all --no-default-features --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" - name: Install cargo-cache continue-on-error: true diff --git a/.github/workflows/osx.yml b/.github/workflows/osx.yml index 5474c552..a926dd34 100644 --- a/.github/workflows/osx.yml +++ b/.github/workflows/osx.yml @@ -37,12 +37,16 @@ jobs: path: ~/.cargo/git key: ${{ matrix.version }}-aarch64-apple-darwin-cargo-index-trimmed-${{ hashFiles('**/Cargo.lock') }} + - name: Run tests (neon) + timeout-minutes: 40 + run: cargo test --all --no-default-features --no-fail-fast --features="ntex/neon,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" + - name: Run tests (tokio) - run: cargo test --all --no-default-features --no-fail-fast --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" + run: cargo test --all --no-default-features --no-fail-fast --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" - name: Run tests (compio) timeout-minutes: 40 - run: cargo test --all --no-default-features --no-fail-fast --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" + run: cargo test --all --no-default-features --no-fail-fast --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" - name: Install cargo-cache continue-on-error: true diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index b42e0f00..8902aa8f 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -63,8 +63,8 @@ jobs: - name: Run tests (tokio) run: | - cargo test --all --lib --no-default-features --no-fail-fast --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" -- --skip test_timer + cargo test --all --lib --no-default-features --no-fail-fast --features="ntex/tokio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" -- --skip test_timer - name: Run tests (compio) run: | - cargo test --all --lib --no-default-features --no-fail-fast --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws,ntex/brotli" -- --skip test_timer + cargo test --all --lib --no-default-features --no-fail-fast --features="ntex/compio,ntex/cookie,ntex/url,ntex/compress,ntex/openssl,ntex/rustls,ntex/ws" -- --skip test_timer diff --git a/Cargo.toml b/Cargo.toml index 13846071..d9e97ef4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,12 +15,18 @@ members = [ "ntex-macros", "ntex-util", - "ntex-async-std", "ntex-compio", - "ntex-glommio", "ntex-tokio", ] +[workspace.package] +authors = ["ntex contributors "] +repository = "https://github.com/ntex-rs/ntex" +documentation = "https://docs.rs/ntex/" +license = "MIT OR Apache-2.0" +edition = "2021" +rust-version = "1.75" + [patch.crates-io] ntex = { path = "ntex" } ntex-bytes = { path = "ntex-bytes" } @@ -37,6 +43,28 @@ ntex-macros = { path = "ntex-macros" } ntex-util = { path = "ntex-util" } ntex-compio = { path = "ntex-compio" } -ntex-glommio = { path = "ntex-glommio" } ntex-tokio = { path = "ntex-tokio" } -ntex-async-std = { path = "ntex-async-std" } + +[workspace.dependencies] +async-channel = "2" +async-task = "4.5.0" +atomic-waker = "1.1" +core_affinity = "0.8" +bitflags = "2" +cfg_aliases = "0.2.1" +cfg-if = "1.0.0" +crossbeam-channel = "0.5.8" +crossbeam-queue = "0.3.8" +futures-util = "0.3.29" +fxhash = "0.2" +libc = "0.2.164" +log = "0.4" +io-uring = "0.7.4" +oneshot = "0.1" +polling = "3.7.4" +nohash-hasher = "0.2.0" +scoped-tls = "1.0.1" +slab = "0.4.9" +socket2 = "0.5.6" +windows-sys = "0.52.0" +thiserror = "1" diff --git a/README.md b/README.md index 73735566..443dbf0d 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

Framework for composable network services.

-[![build status](https://github.com/ntex-rs/ntex/workflows/CI%20%28Linux%29/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions?query=workflow%3A"CI+(Linux)") +[![build status](https://github.com/ntex-rs/ntex/actions/workflows/linux.yml/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions/workflows/linux.yml/badge.svg) [![crates.io](https://img.shields.io/crates/v/ntex.svg)](https://crates.io/crates/ntex) [![Documentation](https://img.shields.io/docsrs/ntex/latest)](https://docs.rs/ntex) [![Version](https://img.shields.io/badge/rustc-1.75+-lightgray.svg)](https://blog.rust-lang.org/2023/12/28/Rust-1.75.0.html) @@ -18,18 +18,18 @@ | Platform | Build Status | | ---------------- | ------------ | -| Linux | [![build status](https://github.com/ntex-rs/ntex/workflows/CI%20%28Linux%29/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions?query=workflow%3A"CI+(Linux)") | -| macOS | [![build status](https://github.com/ntex-rs/ntex/workflows/CI%20%28OSX%29/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions?query=workflow%3A"CI+(OSX)") | -| Windows | [![build status](https://github.com/ntex-rs/ntex/workflows/CI%20%28Windows%29/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions?query=workflow%3A"CI+(Windows)") | +| Linux | [![build status](https://github.com/ntex-rs/ntex/actions/workflows/linux.yml/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions/workflows/linux.yml/badge.svg) | +| macOS | [![build status](https://github.com/ntex-rs/ntex/actions/workflows/osx.yml/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions/workflows/osx.yml/badge.svg) | +| Windows | [![build status](https://github.com/ntex-rs/ntex/actions/workflows/windows.yml/badge.svg?branch=master&event=push)](https://github.com/ntex-rs/ntex/actions/workflows/windows.yml/badge.svg) | ## Usage ntex supports multiple async runtimes, runtime must be selected as a feature. Available options are `compio`, `tokio`, -`glommio` or `async-std`. +`neon` or `neon-uring`. ```toml [dependencies] -ntex = { version = "2", features = ["tokio"] } +ntex = { version = "2", features = ["compio"] } ``` ## Documentation & community resources diff --git a/ntex-async-std/CHANGES.md b/ntex-async-std/CHANGES.md deleted file mode 100644 index 53ba72ac..00000000 --- a/ntex-async-std/CHANGES.md +++ /dev/null @@ -1,45 +0,0 @@ -# Changes - -## [0.4.0] - 2024-01-09 - -* Release - -## [0.4.0-b.0] - 2024-01-07 - -* Use "async fn" in trait for Service definition - -## [0.3.2] - 2023-11-22 - -* Replace async-oneshot with oneshot - -## [0.3.1] - 2023-11-12 - -* Optimize io read task - -## [0.3.0] - 2023-06-22 - -* Release v0.3.0 - -## [0.3.0-beta.0] - 2023-06-16 - -* Migrate to ntex-service 1.2 - -## [0.2.2] - 2023-01-26 - -* Update io api usage - -## [0.2.0] - 2023-01-04 - -* Release - -## [0.2.0-beta.0] - 2022-12-28 - -* Migrate to ntex-service 1.0 - -## [0.1.1] - 2022-01-30 - -* Update to ntex-io 0.1.7 - -## [0.1.0] - 2022-01-03 - -* Initial release diff --git a/ntex-async-std/Cargo.toml b/ntex-async-std/Cargo.toml deleted file mode 100644 index 21d9b93b..00000000 --- a/ntex-async-std/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "ntex-async-std" -version = "0.5.1" -authors = ["ntex contributors "] -description = "async-std intergration for ntex framework" -keywords = ["network", "framework", "async", "futures"] -homepage = "https://ntex.rs" -repository = "https://github.com/ntex-rs/ntex.git" -documentation = "https://docs.rs/ntex-rt-async-std/" -categories = ["network-programming", "asynchronous"] -license = "MIT OR Apache-2.0" -edition = "2021" - -[lib] -name = "ntex_async_std" -path = "src/lib.rs" - -[dependencies] -ntex-bytes = "0.1" -ntex-io = "2.5" -ntex-util = "2.0" -log = "0.4" -async-std = { version = "1", features = ["unstable"] } -oneshot = { version = "0.1", default-features = false, features = ["async"] } diff --git a/ntex-async-std/LICENSE-APACHE b/ntex-async-std/LICENSE-APACHE deleted file mode 120000 index 965b606f..00000000 --- a/ntex-async-std/LICENSE-APACHE +++ /dev/null @@ -1 +0,0 @@ -../LICENSE-APACHE \ No newline at end of file diff --git a/ntex-async-std/LICENSE-MIT b/ntex-async-std/LICENSE-MIT deleted file mode 120000 index 76219eb7..00000000 --- a/ntex-async-std/LICENSE-MIT +++ /dev/null @@ -1 +0,0 @@ -../LICENSE-MIT \ No newline at end of file diff --git a/ntex-async-std/src/io.rs b/ntex-async-std/src/io.rs deleted file mode 100644 index 7180aeae..00000000 --- a/ntex-async-std/src/io.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::{ - any, cell::RefCell, future::poll_fn, io, pin::Pin, task::ready, task::Context, - task::Poll, -}; - -use async_std::io::{Read as ARead, Write as AWrite}; -use ntex_bytes::{Buf, BufMut, BytesVec}; -use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf}; - -use crate::TcpStream; - -impl IoStream for TcpStream { - fn start(self, read: ReadContext, write: WriteContext) -> Option> { - let mut rio = Read(RefCell::new(self.clone())); - async_std::task::spawn_local(async move { - read.handle(&mut rio).await; - }); - let mut wio = Write(RefCell::new(self.clone())); - async_std::task::spawn_local(async move { - write.handle(&mut wio).await; - }); - Some(Box::new(self)) - } -} - -impl Handle for TcpStream { - fn query(&self, id: any::TypeId) -> Option> { - if id == any::TypeId::of::() { - if let Ok(addr) = self.0.peer_addr() { - return Some(Box::new(types::PeerAddr(addr))); - } - } - None - } -} - -/// Read io task -struct Read(RefCell); - -impl ntex_io::AsyncRead for Read { - async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { - // read data from socket - let result = poll_fn(|cx| { - let mut io = self.0.borrow_mut(); - poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) - }) - .await; - (buf, result) - } -} - -struct Write(RefCell); - -impl ntex_io::AsyncWrite for Write { - #[inline] - async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { - poll_fn(|cx| { - if let Some(mut b) = buf.take() { - let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx); - buf.set(b); - result - } else { - Poll::Ready(Ok(())) - } - }) - .await - } - - #[inline] - async fn flush(&mut self) -> io::Result<()> { - Ok(()) - } - - #[inline] - async fn shutdown(&mut self) -> io::Result<()> { - self.0.borrow().0.shutdown(std::net::Shutdown::Both) - } -} - -/// Flush write buffer to underlying I/O stream. -pub(super) fn flush_io( - io: &mut T, - buf: &mut BytesVec, - cx: &mut Context<'_>, -) -> Poll> { - let len = buf.len(); - - if len != 0 { - // log::trace!("flushing framed transport: {:?}", buf.len()); - - let mut written = 0; - let result = loop { - break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("Disconnected during flush, written {}", written); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))) - } else { - written += n; - if written == len { - buf.clear(); - Poll::Ready(Ok(())) - } else { - continue; - } - } - } - Poll::Pending => { - // remove written data - buf.advance(written); - Poll::Pending - } - Poll::Ready(Err(e)) => { - log::trace!("Error during flush: {}", e); - Poll::Ready(Err(e)) - } - }; - }; - // log::trace!("flushed {} bytes", written); - - // flush - if written > 0 { - match Pin::new(&mut *io).poll_flush(cx) { - Poll::Ready(Ok(_)) => result, - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - log::trace!("error during flush: {}", e); - Poll::Ready(Err(e)) - } - } - } else { - result - } - } else { - Poll::Ready(Ok(())) - } -} - -pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut BytesVec, -) -> Poll> { - let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) }; - let n = ready!(io.poll_read(cx, dst))?; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by Read::poll_read() api - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) -} - -#[cfg(unix)] -mod unixstream { - use super::*; - use crate::UnixStream; - - impl IoStream for UnixStream { - fn start(self, read: ReadContext, write: WriteContext) -> Option> { - let mut rio = Read(RefCell::new(self.clone())); - async_std::task::spawn_local(async move { - read.handle(&mut rio).await; - }); - let mut wio = Write(RefCell::new(self)); - async_std::task::spawn_local(async move { - write.handle(&mut wio).await; - }); - None - } - } - - /// Read io task - struct Read(RefCell); - - impl ntex_io::AsyncRead for Read { - async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { - // read data from socket - let result = poll_fn(|cx| { - let mut io = self.0.borrow_mut(); - poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) - }) - .await; - (buf, result) - } - } - - struct Write(RefCell); - - impl ntex_io::AsyncWrite for Write { - #[inline] - async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { - poll_fn(|cx| { - if let Some(mut b) = buf.take() { - let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx); - buf.set(b); - result - } else { - Poll::Ready(Ok(())) - } - }) - .await - } - - #[inline] - async fn flush(&mut self) -> io::Result<()> { - Ok(()) - } - - #[inline] - async fn shutdown(&mut self) -> io::Result<()> { - self.0.borrow().0.shutdown(std::net::Shutdown::Both) - } - } -} diff --git a/ntex-async-std/src/lib.rs b/ntex-async-std/src/lib.rs deleted file mode 100644 index e347d282..00000000 --- a/ntex-async-std/src/lib.rs +++ /dev/null @@ -1,64 +0,0 @@ -use std::{io::Result, net, net::SocketAddr}; - -use ntex_bytes::PoolRef; -use ntex_io::Io; - -mod io; -mod signals; - -pub use self::signals::{signal, Signal}; - -#[derive(Clone)] -struct TcpStream(async_std::net::TcpStream); - -#[cfg(unix)] -#[derive(Clone)] -struct UnixStream(async_std::os::unix::net::UnixStream); - -/// Opens a TCP connection to a remote host. -pub async fn tcp_connect(addr: SocketAddr) -> Result { - let sock = async_std::net::TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; - Ok(Io::new(TcpStream(sock))) -} - -/// Opens a TCP connection to a remote host and use specified memory pool. -pub async fn tcp_connect_in(addr: SocketAddr, pool: PoolRef) -> Result { - let sock = async_std::net::TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; - Ok(Io::with_memory_pool(TcpStream(sock), pool)) -} - -#[cfg(unix)] -/// Opens a unix stream connection. -pub async fn unix_connect

(addr: P) -> Result -where - P: AsRef, -{ - let sock = async_std::os::unix::net::UnixStream::connect(addr).await?; - Ok(Io::new(UnixStream(sock))) -} - -#[cfg(unix)] -/// Opens a unix stream connection and specified memory pool. -pub async fn unix_connect_in

(addr: P, pool: PoolRef) -> Result -where - P: AsRef, -{ - let sock = async_std::os::unix::net::UnixStream::connect(addr).await?; - Ok(Io::with_memory_pool(UnixStream(sock), pool)) -} - -/// Convert std TcpStream to async-std's TcpStream -pub fn from_tcp_stream(stream: net::TcpStream) -> Result { - stream.set_nonblocking(true)?; - stream.set_nodelay(true)?; - Ok(Io::new(TcpStream(async_std::net::TcpStream::from(stream)))) -} - -#[cfg(unix)] -/// Convert std UnixStream to async-std's UnixStream -pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { - stream.set_nonblocking(true)?; - Ok(Io::new(UnixStream(From::from(stream)))) -} diff --git a/ntex-async-std/src/signals.rs b/ntex-async-std/src/signals.rs deleted file mode 100644 index d90135ad..00000000 --- a/ntex-async-std/src/signals.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll}; - -thread_local! { - static SRUN: RefCell = const { RefCell::new(false) }; - static SHANDLERS: Rc>>> = Default::default(); -} - -/// Different types of process signals -#[derive(PartialEq, Eq, Clone, Copy, Debug)] -pub enum Signal { - /// SIGHUP - Hup, - /// SIGINT - Int, - /// SIGTERM - Term, - /// SIGQUIT - Quit, -} - -/// Register signal handler. -/// -/// Signals are handled by oneshots, you have to re-register -/// after each signal. -pub fn signal() -> Option> { - if !SRUN.with(|v| *v.borrow()) { - async_std::task::spawn_local(Signals::new()); - } - SHANDLERS.with(|handlers| { - let (tx, rx) = oneshot::channel(); - handlers.borrow_mut().push(tx); - Some(rx) - }) -} - -struct Signals {} - -impl Signals { - pub(super) fn new() -> Signals { - Self {} - } -} - -impl Future for Signals { - type Output = (); - - fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - Poll::Ready(()) - } -} diff --git a/ntex-bytes/src/bytes.rs b/ntex-bytes/src/bytes.rs index 77267307..b7eb8eea 100644 --- a/ntex-bytes/src/bytes.rs +++ b/ntex-bytes/src/bytes.rs @@ -3779,7 +3779,7 @@ impl PartialEq for [u8; N] { } } -impl<'a, const N: usize> PartialEq for &'a [u8; N] { +impl PartialEq for &[u8; N] { fn eq(&self, other: &BytesMut) -> bool { *other == *self } @@ -3878,7 +3878,7 @@ impl PartialEq for [u8; N] { } } -impl<'a, const N: usize> PartialEq for &'a [u8; N] { +impl PartialEq for &[u8; N] { fn eq(&self, other: &Bytes) -> bool { *other == *self } @@ -4076,7 +4076,7 @@ impl PartialEq for [u8; N] { } } -impl<'a, const N: usize> PartialEq for &'a [u8; N] { +impl PartialEq for &[u8; N] { fn eq(&self, other: &BytesVec) -> bool { *other == *self } diff --git a/ntex-bytes/src/hex.rs b/ntex-bytes/src/hex.rs index 46ad1e1f..109b2462 100644 --- a/ntex-bytes/src/hex.rs +++ b/ntex-bytes/src/hex.rs @@ -3,7 +3,7 @@ use std::fmt::{Formatter, LowerHex, Result, UpperHex}; struct BytesRef<'a>(&'a [u8]); -impl<'a> LowerHex for BytesRef<'a> { +impl LowerHex for BytesRef<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for b in self.0 { write!(f, "{b:02x}")?; @@ -12,7 +12,7 @@ impl<'a> LowerHex for BytesRef<'a> { } } -impl<'a> UpperHex for BytesRef<'a> { +impl UpperHex for BytesRef<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for b in self.0 { write!(f, "{b:02X}")?; diff --git a/ntex-compio/CHANGES.md b/ntex-compio/CHANGES.md index e31f8cc6..bc224118 100644 --- a/ntex-compio/CHANGES.md +++ b/ntex-compio/CHANGES.md @@ -1,5 +1,17 @@ # Changes +## [0.2.4] - 2024-12-01 + +* Depend on individual compio packages + +## [0.2.3] - 2024-11-27 + +* Disable default features + +## [0.2.2] - 2024-11-25 + +* Update to compio 0.13 + ## [0.2.1] - 2024-10-31 * It's not required to close compio sockets explicitly #444 diff --git a/ntex-compio/Cargo.toml b/ntex-compio/Cargo.toml index 4799ff8f..aae3509a 100644 --- a/ntex-compio/Cargo.toml +++ b/ntex-compio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-compio" -version = "0.2.1" +version = "0.2.4" authors = ["ntex contributors "] description = "compio runtime intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -20,5 +20,11 @@ path = "src/lib.rs" ntex-bytes = "0.1" ntex-io = "2.5" ntex-util = "2" +ntex-rt = "0.4" log = "0.4" -compio = { version = "0.12.0", features = ["macros", "io", "runtime"] } + +compio-buf = "0.5" +compio-io = "0.5" +compio-net = "0.6" +compio-driver = "0.6" +compio-runtime = { version = "0.6", features = ["io-uring", "polling", "event"] } diff --git a/ntex-compio/src/io.rs b/ntex-compio/src/io.rs index 8e80c860..81d41429 100644 --- a/ntex-compio/src/io.rs +++ b/ntex-compio/src/io.rs @@ -1,15 +1,15 @@ use std::{any, io}; -use compio::buf::{BufResult, IoBuf, IoBufMut, SetBufInit}; -use compio::io::{AsyncRead, AsyncWrite}; -use compio::net::TcpStream; +use compio_buf::{BufResult, IoBuf, IoBufMut, SetBufInit}; +use compio_io::{AsyncRead, AsyncWrite}; +use compio_net::TcpStream; use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf}; impl IoStream for crate::TcpStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { let io = self.0.clone(); - compio::runtime::spawn(async move { run(io.clone(), &read, write).await }).detach(); + compio_runtime::spawn(async move { run(io.clone(), &read, write).await }).detach(); Some(Box::new(HandleWrapper(self.0))) } @@ -18,7 +18,7 @@ impl IoStream for crate::TcpStream { #[cfg(unix)] impl IoStream for crate::UnixStream { fn start(self, read: ReadContext, write: WriteContext) -> Option> { - compio::runtime::spawn(async move { run(self.0.clone(), &read, write).await }) + compio_runtime::spawn(async move { run(self.0.clone(), &read, write).await }) .detach(); None @@ -75,7 +75,7 @@ async fn run( write: WriteContext, ) { let mut wr_io = WriteIo(io.clone()); - let wr_task = compio::runtime::spawn(async move { + let wr_task = compio_runtime::spawn(async move { write.handle(&mut wr_io).await; log::debug!("{} Write task is stopped", write.tag()); }); diff --git a/ntex-compio/src/lib.rs b/ntex-compio/src/lib.rs index 4d034824..9b4dbae6 100644 --- a/ntex-compio/src/lib.rs +++ b/ntex-compio/src/lib.rs @@ -6,21 +6,21 @@ use ntex_io::Io; mod io; /// Tcp stream wrapper for compio TcpStream -struct TcpStream(compio::net::TcpStream); +struct TcpStream(compio_net::TcpStream); #[cfg(unix)] /// Tcp stream wrapper for compio UnixStream -struct UnixStream(compio::net::UnixStream); +struct UnixStream(compio_net::UnixStream); /// Opens a TCP connection to a remote host. pub async fn tcp_connect(addr: SocketAddr) -> Result { - let sock = compio::net::TcpStream::connect(addr).await?; + let sock = compio_net::TcpStream::connect(addr).await?; Ok(Io::new(TcpStream(sock))) } /// Opens a TCP connection to a remote host and use specified memory pool. pub async fn tcp_connect_in(addr: SocketAddr, pool: PoolRef) -> Result { - let sock = compio::net::TcpStream::connect(addr).await?; + let sock = compio_net::TcpStream::connect(addr).await?; Ok(Io::with_memory_pool(TcpStream(sock), pool)) } @@ -30,7 +30,7 @@ pub async fn unix_connect<'a, P>(addr: P) -> Result where P: AsRef + 'a, { - let sock = compio::net::UnixStream::connect(addr).await?; + let sock = compio_net::UnixStream::connect(addr).await?; Ok(Io::new(UnixStream(sock))) } @@ -40,22 +40,20 @@ pub async fn unix_connect_in<'a, P>(addr: P, pool: PoolRef) -> Result where P: AsRef + 'a, { - let sock = compio::net::UnixStream::connect(addr).await?; + let sock = compio_net::UnixStream::connect(addr).await?; Ok(Io::with_memory_pool(UnixStream(sock), pool)) } /// Convert std TcpStream to tokio's TcpStream pub fn from_tcp_stream(stream: net::TcpStream) -> Result { stream.set_nodelay(true)?; - Ok(Io::new(TcpStream(compio::net::TcpStream::from_std( - stream, - )?))) + Ok(Io::new(TcpStream(compio_net::TcpStream::from_std(stream)?))) } #[cfg(unix)] /// Convert std UnixStream to tokio's UnixStream pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { - Ok(Io::new(UnixStream(compio::net::UnixStream::from_std( + Ok(Io::new(UnixStream(compio_net::UnixStream::from_std( stream, )?))) } diff --git a/ntex-glommio/CHANGES.md b/ntex-glommio/CHANGES.md deleted file mode 100644 index 30faefde..00000000 --- a/ntex-glommio/CHANGES.md +++ /dev/null @@ -1,57 +0,0 @@ -# Changes - -## [0.5.2] - 2024-09-xx - -* Update to glommio v0.9 - -## [0.4.0] - 2024-01-09 - -* Release - -## [0.4.0-b.0] - 2024-01-07 - -* Use "async fn" in trait for Service definition - -## [0.3.1] - 2023-11-22 - -* Replace async-oneshot with oneshot - -## [0.3.0] - 2023-06-22 - -* Release v0.3.0 - -## [0.3.0-beta.0] - 2023-06-16 - -* Migrate to ntex-service 1.2 - -## [0.2.4] - 2023-05-30 - -* Fix borrow mut panic #204 - -## [0.2.3] - 2023-04-11 - -* Chore upgrade glommio to 0.8 - -## [0.2.2] - 2023-01-26 - -* Update io api usage - -## [0.2.0] - 2023-01-04 - -* Release - -## [0.2.0-beta.0] - 2022-12-28 - -* Migrate to ntex-service 1.0 - -## [0.1.2] - 2022-02-20 - -* Upgrade to glommio 0.7 - -## [0.1.1] - 2022-01-30 - -* Update to ntex-io 0.1.7 - -## [0.1.0] - 2022-01-17 - -* Initial release diff --git a/ntex-glommio/Cargo.toml b/ntex-glommio/Cargo.toml deleted file mode 100644 index bb9ec0ed..00000000 --- a/ntex-glommio/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -name = "ntex-glommio" -version = "0.5.2" -authors = ["ntex contributors "] -description = "glommio intergration for ntex framework" -keywords = ["network", "framework", "async", "futures"] -homepage = "https://ntex.rs" -repository = "https://github.com/ntex-rs/ntex.git" -documentation = "https://docs.rs/ntex-rt-glommio/" -categories = ["network-programming", "asynchronous"] -license = "MIT OR Apache-2.0" -edition = "2021" - -[lib] -name = "ntex_glommio" -path = "src/lib.rs" - -[dependencies] -ntex-bytes = "0.1" -ntex-io = "2.5" -ntex-util = "2.0" -futures-lite = "2.2" -log = "0.4" -oneshot = { version = "0.1", default-features = false, features = ["async"] } - -[target.'cfg(target_os = "linux")'.dependencies] -glommio = "0.9" diff --git a/ntex-glommio/LICENSE-APACHE b/ntex-glommio/LICENSE-APACHE deleted file mode 120000 index 965b606f..00000000 --- a/ntex-glommio/LICENSE-APACHE +++ /dev/null @@ -1 +0,0 @@ -../LICENSE-APACHE \ No newline at end of file diff --git a/ntex-glommio/LICENSE-MIT b/ntex-glommio/LICENSE-MIT deleted file mode 120000 index 76219eb7..00000000 --- a/ntex-glommio/LICENSE-MIT +++ /dev/null @@ -1 +0,0 @@ -../LICENSE-MIT \ No newline at end of file diff --git a/ntex-glommio/src/io.rs b/ntex-glommio/src/io.rs deleted file mode 100644 index 09fc0616..00000000 --- a/ntex-glommio/src/io.rs +++ /dev/null @@ -1,205 +0,0 @@ -use std::{any, future::poll_fn, io, pin::Pin, task::ready, task::Context, task::Poll}; - -use futures_lite::io::{AsyncRead, AsyncWrite}; -use ntex_bytes::{Buf, BufMut, BytesVec}; -use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf}; - -use crate::net_impl::{TcpStream, UnixStream}; - -impl IoStream for TcpStream { - fn start(self, read: ReadContext, write: WriteContext) -> Option> { - let mut rio = Read(self.clone()); - glommio::spawn_local(async move { read.handle(&mut rio).await }).detach(); - let mut wio = Write(self.clone()); - glommio::spawn_local(async move { write.handle(&mut wio).await }).detach(); - Some(Box::new(self)) - } -} - -impl IoStream for UnixStream { - fn start(self, read: ReadContext, write: WriteContext) -> Option> { - let mut rio = UnixRead(self.clone()); - glommio::spawn_local(async move { - read.handle(&mut rio).await; - }) - .detach(); - let mut wio = UnixWrite(self); - glommio::spawn_local(async move { write.handle(&mut wio).await }).detach(); - None - } -} - -impl Handle for TcpStream { - fn query(&self, id: any::TypeId) -> Option> { - if id == any::TypeId::of::() { - if let Ok(addr) = self.0.borrow().peer_addr() { - return Some(Box::new(types::PeerAddr(addr))); - } - } - None - } -} - -/// Read io task -struct Read(TcpStream); - -impl ntex_io::AsyncRead for Read { - async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { - // read data from socket - let result = poll_fn(|cx| { - let mut io = self.0 .0.borrow_mut(); - poll_read_buf(Pin::new(&mut *io), cx, &mut buf) - }) - .await; - (buf, result) - } -} - -struct Write(TcpStream); - -impl ntex_io::AsyncWrite for Write { - #[inline] - async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { - poll_fn(|cx| { - if let Some(mut b) = buf.take() { - let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx); - buf.set(b); - result - } else { - Poll::Ready(Ok(())) - } - }) - .await - } - - #[inline] - async fn flush(&mut self) -> io::Result<()> { - Ok(()) - } - - #[inline] - async fn shutdown(&mut self) -> io::Result<()> { - poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await - } -} - -struct UnixRead(UnixStream); - -impl ntex_io::AsyncRead for UnixRead { - async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result) { - // read data from socket - let result = poll_fn(|cx| { - let mut io = self.0 .0.borrow_mut(); - poll_read_buf(Pin::new(&mut *io), cx, &mut buf) - }) - .await; - (buf, result) - } -} - -struct UnixWrite(UnixStream); - -impl ntex_io::AsyncWrite for UnixWrite { - #[inline] - async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> { - poll_fn(|cx| { - if let Some(mut b) = buf.take() { - let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx); - buf.set(b); - result - } else { - Poll::Ready(Ok(())) - } - }) - .await - } - - #[inline] - async fn flush(&mut self) -> io::Result<()> { - Ok(()) - } - - #[inline] - async fn shutdown(&mut self) -> io::Result<()> { - poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await - } -} - -/// Flush write buffer to underlying I/O stream. -pub(super) fn flush_io( - io: &mut T, - buf: &mut BytesVec, - cx: &mut Context<'_>, -) -> Poll> { - let len = buf.len(); - - if len != 0 { - // log::trace!("flushing framed transport: {:?}", buf.len()); - - let mut written = 0; - let result = loop { - break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("Disconnected during flush, written {}", written); - Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ))) - } else { - written += n; - if written == len { - buf.clear(); - Poll::Ready(Ok(())) - } else { - continue; - } - } - } - Poll::Pending => { - // remove written data - buf.advance(written); - Poll::Pending - } - Poll::Ready(Err(e)) => { - log::trace!("Error during flush: {}", e); - Poll::Ready(Err(e)) - } - }; - }; - // log::trace!("flushed {} bytes", written); - - // flush - if written > 0 { - match Pin::new(&mut *io).poll_flush(cx) { - Poll::Ready(Ok(_)) => result, - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - log::trace!("error during flush: {}", e); - Poll::Ready(Err(e)) - } - } - } else { - result - } - } else { - Poll::Ready(Ok(())) - } -} - -pub fn poll_read_buf( - io: Pin<&mut T>, - cx: &mut Context<'_>, - buf: &mut BytesVec, -) -> Poll> { - let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) }; - let n = ready!(io.poll_read(cx, dst))?; - - // Safety: This is guaranteed to be the number of initialized (and read) - // bytes due to the invariants provided by Read::poll_read() api - unsafe { - buf.advance_mut(n); - } - - Poll::Ready(Ok(n)) -} diff --git a/ntex-glommio/src/lib.rs b/ntex-glommio/src/lib.rs deleted file mode 100644 index 8c8885b5..00000000 --- a/ntex-glommio/src/lib.rs +++ /dev/null @@ -1,90 +0,0 @@ -#[cfg(target_os = "linux")] -mod io; -#[cfg(target_os = "linux")] -mod signals; - -#[cfg(target_os = "linux")] -pub use self::signals::{signal, Signal}; - -#[cfg(target_os = "linux")] -mod net_impl { - use std::os::unix::io::{FromRawFd, IntoRawFd}; - use std::{cell::RefCell, io::Result, net, net::SocketAddr, rc::Rc}; - - use ntex_bytes::PoolRef; - use ntex_io::Io; - - #[derive(Clone)] - pub(crate) struct TcpStream(pub(crate) Rc>); - - impl TcpStream { - fn new(io: glommio::net::TcpStream) -> Self { - Self(Rc::new(RefCell::new(io))) - } - } - - #[derive(Clone)] - pub(crate) struct UnixStream(pub(crate) Rc>); - - impl UnixStream { - fn new(io: glommio::net::UnixStream) -> Self { - Self(Rc::new(RefCell::new(io))) - } - } - - /// Opens a TCP connection to a remote host. - pub async fn tcp_connect(addr: SocketAddr) -> Result { - let sock = glommio::net::TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; - Ok(Io::new(TcpStream::new(sock))) - } - - /// Opens a TCP connection to a remote host and use specified memory pool. - pub async fn tcp_connect_in(addr: SocketAddr, pool: PoolRef) -> Result { - let sock = glommio::net::TcpStream::connect(addr).await?; - sock.set_nodelay(true)?; - Ok(Io::with_memory_pool(TcpStream::new(sock), pool)) - } - - /// Opens a unix stream connection. - pub async fn unix_connect

(addr: P) -> Result - where - P: AsRef, - { - let sock = glommio::net::UnixStream::connect(addr).await?; - Ok(Io::new(UnixStream::new(sock))) - } - - /// Opens a unix stream connection and specified memory pool. - pub async fn unix_connect_in

(addr: P, pool: PoolRef) -> Result - where - P: AsRef, - { - let sock = glommio::net::UnixStream::connect(addr).await?; - Ok(Io::with_memory_pool(UnixStream::new(sock), pool)) - } - - /// Convert std TcpStream to glommio's TcpStream - pub fn from_tcp_stream(stream: net::TcpStream) -> Result { - stream.set_nonblocking(true)?; - stream.set_nodelay(true)?; - unsafe { - Ok(Io::new(TcpStream::new( - glommio::net::TcpStream::from_raw_fd(stream.into_raw_fd()), - ))) - } - } - - /// Convert std UnixStream to glommio's UnixStream - pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { - stream.set_nonblocking(true)?; - unsafe { - Ok(Io::new(UnixStream::new( - glommio::net::UnixStream::from_raw_fd(stream.into_raw_fd()), - ))) - } - } -} - -#[cfg(target_os = "linux")] -pub use self::net_impl::*; diff --git a/ntex-glommio/src/signals.rs b/ntex-glommio/src/signals.rs deleted file mode 100644 index 366ace9c..00000000 --- a/ntex-glommio/src/signals.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::{cell::RefCell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll}; - -thread_local! { - static SRUN: RefCell = const { RefCell::new(false) }; - static SHANDLERS: Rc>>> = Default::default(); -} - -/// Different types of process signals -#[derive(PartialEq, Clone, Copy, Debug)] -pub enum Signal { - /// SIGHUP - Hup, - /// SIGINT - Int, - /// SIGTERM - Term, - /// SIGQUIT - Quit, -} - -/// Register signal handler. -/// -/// Signals are handled by oneshots, you have to re-register -/// after each signal. -pub fn signal() -> Option> { - if !SRUN.with(|v| *v.borrow()) { - glommio::spawn_local(Signals::new()).detach(); - } - SHANDLERS.with(|handlers| { - let (tx, rx) = oneshot::channel(); - handlers.borrow_mut().push(tx); - Some(rx) - }) -} - -struct Signals {} - -impl Signals { - pub(super) fn new() -> Signals { - Self {} - } -} - -impl Future for Signals { - type Output = (); - - fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { - Poll::Ready(()) - } -} diff --git a/ntex-http/CHANGES.md b/ntex-http/CHANGES.md index 875dd380..deac9c2c 100644 --- a/ntex-http/CHANGES.md +++ b/ntex-http/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.1.13] - 2024-01-30 + +* Move body related types from ntex::http + ## [0.1.12] - 2024-01-16 * Update http dependency diff --git a/ntex-http/Cargo.toml b/ntex-http/Cargo.toml index 40929a7a..ccb4b5a6 100644 --- a/ntex-http/Cargo.toml +++ b/ntex-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-http" -version = "0.1.12" +version = "0.1.13" authors = ["ntex contributors "] description = "Http types for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -20,9 +20,14 @@ http = "1" log = "0.4" fxhash = "0.2.1" itoa = "1.0.4" -ntex-bytes = "0.1.21" +ntex-bytes = "0.1" serde = "1" +futures-core = { version = "0.3", default-features = false, features = ["alloc"] } [dev-dependencies] bincode = "1" serde_json = "1" +ntex = "2" +ntex-util = "2" +ntex-macros = "0.1.3" +futures-util = { version = "0.3", default-features = false, features = ["alloc"] } diff --git a/ntex/src/http/body.rs b/ntex-http/src/body.rs similarity index 94% rename from ntex/src/http/body.rs rename to ntex-http/src/body.rs index 048854e7..f00f0d18 100644 --- a/ntex/src/http/body.rs +++ b/ntex-http/src/body.rs @@ -1,8 +1,10 @@ +//! Traits and structures to aid consuming and writing HTTP payloads. use std::{ error::Error, fmt, marker::PhantomData, mem, pin::Pin, task::Context, task::Poll, }; -use crate::util::{Bytes, BytesMut, Stream}; +use futures_core::Stream; +use ntex_bytes::{Bytes, BytesMut}; #[derive(Debug, PartialEq, Eq, Copy, Clone)] /// Body size hint @@ -19,8 +21,9 @@ impl BodySize { } } -/// Type that provides this trait can be streamed to a peer. +/// Interface for types that can be streamed to a peer. pub trait MessageBody: 'static { + /// Message body size hind fn size(&self) -> BodySize; fn poll_next_chunk( @@ -30,10 +33,12 @@ pub trait MessageBody: 'static { } impl MessageBody for () { + #[inline] fn size(&self) -> BodySize { BodySize::Empty } + #[inline] fn poll_next_chunk( &mut self, _: &mut Context<'_>, @@ -43,10 +48,12 @@ impl MessageBody for () { } impl MessageBody for Box { + #[inline] fn size(&self) -> BodySize { self.as_ref().size() } + #[inline] fn poll_next_chunk( &mut self, cx: &mut Context<'_>, @@ -56,6 +63,7 @@ impl MessageBody for Box { } #[derive(Debug)] +/// Represents http response body pub enum ResponseBody { Body(B), Other(Body), @@ -86,10 +94,12 @@ impl From for ResponseBody { } impl ResponseBody { + #[inline] pub fn new(body: B) -> Self { ResponseBody::Body(body) } + #[inline] pub fn take_body(&mut self) -> ResponseBody { std::mem::replace(self, ResponseBody::Other(Body::None)) } @@ -106,6 +116,7 @@ impl ResponseBody { } impl MessageBody for ResponseBody { + #[inline] fn size(&self) -> BodySize { match self { ResponseBody::Body(ref body) => body.size(), @@ -113,6 +124,7 @@ impl MessageBody for ResponseBody { } } + #[inline] fn poll_next_chunk( &mut self, cx: &mut Context<'_>, @@ -154,12 +166,13 @@ impl Body { } /// Create body from generic message body. - pub fn from_message(body: B) -> Body { + pub fn from_message(body: B) -> Body { Body::Message(Box::new(body)) } } impl MessageBody for Body { + #[inline] fn size(&self) -> BodySize { match self { Body::None => BodySize::None, @@ -253,12 +266,6 @@ impl From for Body { } } -impl From for Body { - fn from(v: serde_json::Value) -> Body { - Body::Bytes(v.to_string().into()) - } -} - impl From> for Body where S: Stream>> + Unpin + 'static, @@ -551,11 +558,12 @@ where #[cfg(test)] mod tests { - use futures_util::stream; use std::{future::poll_fn, io}; + use futures_util::stream; + use ntex_util::future::Ready; + use super::*; - use crate::util::Ready; impl Body { pub(crate) fn get_ref(&self) -> &[u8] { @@ -566,16 +574,7 @@ mod tests { } } - impl ResponseBody { - pub(crate) fn get_ref(&self) -> &[u8] { - match *self { - ResponseBody::Body(ref b) => b.get_ref(), - ResponseBody::Other(ref b) => b.get_ref(), - } - } - } - - #[crate::rt_test] + #[ntex::test] async fn test_static_str() { assert_eq!(Body::from("").size(), BodySize::Sized(0)); assert_eq!(Body::from("test").size(), BodySize::Sized(4)); @@ -593,7 +592,7 @@ mod tests { assert!(poll_fn(|cx| "".poll_next_chunk(cx)).await.is_none()); } - #[crate::rt_test] + #[ntex::test] async fn test_static_bytes() { assert_eq!(Body::from(b"test".as_ref()).size(), BodySize::Sized(4)); assert_eq!(Body::from(b"test".as_ref()).get_ref(), b"test"); @@ -615,7 +614,7 @@ mod tests { assert!(poll_fn(|cx| (&b""[..]).poll_next_chunk(cx)).await.is_none()); } - #[crate::rt_test] + #[ntex::test] async fn test_vec() { assert_eq!(Body::from(Vec::from("test")).size(), BodySize::Sized(4)); assert_eq!(Body::from(Vec::from("test")).get_ref(), b"test"); @@ -640,7 +639,7 @@ mod tests { .is_none()); } - #[crate::rt_test] + #[ntex::test] async fn test_bytes() { let mut b = Bytes::from("test"); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); @@ -654,7 +653,7 @@ mod tests { assert!(poll_fn(|cx| b.poll_next_chunk(cx)).await.is_none(),); } - #[crate::rt_test] + #[ntex::test] async fn test_bytes_mut() { let mut b = Body::from(BytesMut::from("test")); assert_eq!(b.size(), BodySize::Sized(4)); @@ -675,7 +674,7 @@ mod tests { assert!(poll_fn(|cx| b.poll_next_chunk(cx)).await.is_none(),); } - #[crate::rt_test] + #[ntex::test] async fn test_string() { let mut b = "test".to_owned(); assert_eq!(Body::from(b.clone()).size(), BodySize::Sized(4)); @@ -691,20 +690,20 @@ mod tests { assert!(poll_fn(|cx| b.poll_next_chunk(cx)).await.is_none(),); } - #[crate::rt_test] + #[ntex::test] async fn test_unit() { assert_eq!(().size(), BodySize::Empty); assert!(poll_fn(|cx| ().poll_next_chunk(cx)).await.is_none()); } - #[crate::rt_test] + #[ntex::test] async fn test_box() { let mut val = Box::new(()); assert_eq!(val.size(), BodySize::Empty); assert!(poll_fn(|cx| val.poll_next_chunk(cx)).await.is_none()); } - #[crate::rt_test] + #[ntex::test] #[allow(clippy::eq_op)] async fn test_body_eq() { assert!(Body::None == Body::None); @@ -717,27 +716,14 @@ mod tests { assert!(Body::Bytes(Bytes::from_static(b"1")) != Body::None); } - #[crate::rt_test] + #[ntex::test] async fn test_body_debug() { assert!(format!("{:?}", Body::None).contains("Body::None")); assert!(format!("{:?}", Body::Empty).contains("Body::Empty")); assert!(format!("{:?}", Body::Bytes(Bytes::from_static(b"1"))).contains('1')); } - #[crate::rt_test] - async fn test_serde_json() { - use serde_json::json; - assert_eq!( - Body::from(serde_json::Value::String("test".into())).size(), - BodySize::Sized(6) - ); - assert_eq!( - Body::from(json!({"test-key":"test-value"})).size(), - BodySize::Sized(25) - ); - } - - #[crate::rt_test] + #[ntex::test] async fn body_stream() { let st = BodyStream::new(stream::once(Ready::<_, io::Error>::Ok(Bytes::from("1")))); assert!(format!("{:?}", st).contains("BodyStream")); @@ -749,7 +735,7 @@ mod tests { assert!(res.as_ref().is_some()); } - #[crate::rt_test] + #[ntex::test] async fn boxed_body_stream() { let st = BoxedBodyStream::new(stream::once(Ready::<_, Box>::Ok( Bytes::from("1"), @@ -763,7 +749,7 @@ mod tests { assert!(res.as_ref().is_some()); } - #[crate::rt_test] + #[ntex::test] async fn body_skips_empty_chunks() { let mut body = BodyStream::new(stream::iter( ["1", "", "2"] @@ -780,7 +766,7 @@ mod tests { ); } - #[crate::rt_test] + #[ntex::test] async fn sized_skips_empty_chunks() { let mut body = SizedStream::new( 2, diff --git a/ntex-http/src/lib.rs b/ntex-http/src/lib.rs index faefe9b3..88a962f3 100644 --- a/ntex-http/src/lib.rs +++ b/ntex-http/src/lib.rs @@ -1,6 +1,7 @@ //! Http protocol support. #![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)] +pub mod body; pub mod error; mod map; mod serde; diff --git a/ntex-http/src/map.rs b/ntex-http/src/map.rs index 53124b55..16e0fcd8 100644 --- a/ntex-http/src/map.rs +++ b/ntex-http/src/map.rs @@ -354,13 +354,13 @@ impl AsName for HeaderName { } } -impl<'a> AsName for &'a HeaderName { +impl AsName for &HeaderName { fn as_name(&self) -> Either<&HeaderName, &str> { Either::Left(self) } } -impl<'a> AsName for &'a str { +impl AsName for &str { fn as_name(&self) -> Either<&HeaderName, &str> { Either::Right(self) } @@ -372,7 +372,7 @@ impl AsName for String { } } -impl<'a> AsName for &'a String { +impl AsName for &String { fn as_name(&self) -> Either<&HeaderName, &str> { Either::Right(self.as_str()) } diff --git a/ntex-http/src/serde.rs b/ntex-http/src/serde.rs index 88ae23e0..b282b3b7 100644 --- a/ntex-http/src/serde.rs +++ b/ntex-http/src/serde.rs @@ -158,7 +158,7 @@ impl<'de> Deserialize<'de> for HeaderValue { struct HeaderValueVisitor; -impl<'de> Visitor<'de> for HeaderValueVisitor { +impl Visitor<'_> for HeaderValueVisitor { type Value = HeaderValue; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/ntex-http/src/value.rs b/ntex-http/src/value.rs index 27b197c6..e2db9036 100644 --- a/ntex-http/src/value.rs +++ b/ntex-http/src/value.rs @@ -641,14 +641,14 @@ impl PartialOrd for String { } } -impl<'a> PartialEq for &'a HeaderValue { +impl PartialEq for &HeaderValue { #[inline] fn eq(&self, other: &HeaderValue) -> bool { **self == *other } } -impl<'a> PartialOrd for &'a HeaderValue { +impl PartialOrd for &HeaderValue { #[inline] fn partial_cmp(&self, other: &HeaderValue) -> Option { (**self).partial_cmp(other) @@ -675,14 +675,14 @@ where } } -impl<'a> PartialEq for &'a str { +impl PartialEq for &str { #[inline] fn eq(&self, other: &HeaderValue) -> bool { *other == *self } } -impl<'a> PartialOrd for &'a str { +impl PartialOrd for &str { #[inline] fn partial_cmp(&self, other: &HeaderValue) -> Option { self.as_bytes().partial_cmp(other.as_bytes()) diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 9ab1f0ac..c109a752 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,5 +1,33 @@ # Changes +## [2.11.1] - 2025-03-20 + +* Add readiness check support + +## [2.11.0] - 2025-03-10 + +* Add single io context + +## [2.10.0] - 2025-02-26 + +* Impl Filter for Sealed #506 + +## [2.9.3] - 2025-01-21 + +* Allow to access io write destination buffer + +## [2.9.2] - 2024-12-05 + +* Better error handling + +## [2.9.1] - 2024-12-04 + +* Check service readiness for every turn + +## [2.9.0] - 2024-12-04 + +* Use updated Service trait + ## [2.8.3] - 2024-11-10 * Check service readiness once per decoded item diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index 965a9029..f55aa5d0 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "2.8.3" +version = "2.11.1" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -18,9 +18,8 @@ path = "src/lib.rs" [dependencies] ntex-codec = "0.6" ntex-bytes = "0.1" -ntex-util = "2.5" -ntex-service = "3.3.3" -ntex-rt = "0.4" +ntex-util = "2.8" +ntex-service = "3.4" bitflags = "2" log = "0.4" @@ -29,4 +28,3 @@ pin-project-lite = "0.2" [dev-dependencies] ntex = "2" rand = "0.8" -env_logger = "0.11" diff --git a/ntex-io/src/buf.rs b/ntex-io/src/buf.rs index 7d4624f0..e3d701d3 100644 --- a/ntex-io/src/buf.rs +++ b/ntex-io/src/buf.rs @@ -152,6 +152,27 @@ impl Stack { } } + pub(crate) fn with_read_source(&self, io: &IoRef, f: F) -> R + where + F: FnOnce(&mut BytesVec) -> R, + { + let item = self.get_last_level(); + let mut rb = item.0.take(); + if rb.is_none() { + rb = Some(io.memory_pool().get_read_buf()); + } + + let result = f(rb.as_mut().unwrap()); + if let Some(b) = rb { + if b.is_empty() { + io.memory_pool().release_read_buf(b); + } else { + item.0.set(Some(b)); + } + } + result + } + pub(crate) fn with_read_destination(&self, io: &IoRef, f: F) -> R where F: FnOnce(&mut BytesVec) -> R, @@ -218,12 +239,12 @@ impl Stack { pub(crate) fn with_write_destination(&self, io: &IoRef, f: F) -> R where - F: FnOnce(&mut Option) -> R, + F: FnOnce(Option<&mut BytesVec>) -> R, { let item = self.get_last_level(); let mut wb = item.1.take(); - let result = f(&mut wb); + let result = f(wb.as_mut()); // check nested updates if item.1.take().is_some() { @@ -300,7 +321,7 @@ pub struct ReadBuf<'a> { pub(crate) need_write: Cell, } -impl<'a> ReadBuf<'a> { +impl ReadBuf<'_> { #[inline] /// Get io tag pub fn tag(&self) -> &'static str { @@ -444,7 +465,7 @@ pub struct WriteBuf<'a> { pub(crate) need_write: Cell, } -impl<'a> WriteBuf<'a> { +impl WriteBuf<'_> { #[inline] /// Get io tag pub fn tag(&self) -> &'static str { diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs index e2c1de36..4c03d312 100644 --- a/ntex-io/src/dispatcher.rs +++ b/ntex-io/src/dispatcher.rs @@ -1,7 +1,7 @@ //! Framed transport dispatcher #![allow(clippy::let_underscore_future)] use std::task::{ready, Context, Poll}; -use std::{cell::Cell, future::poll_fn, future::Future, pin::Pin, rc::Rc}; +use std::{cell::Cell, future::Future, pin::Pin, rc::Rc}; use ntex_codec::{Decoder, Encoder}; use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service}; @@ -131,7 +131,6 @@ bitflags::bitflags! { const KA_ENABLED = 0b0000100; const KA_TIMEOUT = 0b0001000; const READ_TIMEOUT = 0b0010000; - const READY_TASK = 0b1000000; } } @@ -161,7 +160,6 @@ where service: PipelineBinding>, error: Cell::Error>>>, inflight: Cell, - ready: Cell, } #[derive(Copy, Clone, Debug)] @@ -223,7 +221,6 @@ where codec, error: Cell::new(None), inflight: Cell::new(0), - ready: Cell::new(false), service: Pipeline::new(service.into_service()).bind(), }); @@ -284,12 +281,6 @@ where } } - // ready task - if slf.flags.contains(Flags::READY_TASK) { - slf.flags.insert(Flags::READY_TASK); - ntex_rt::spawn(not_ready(slf.shared.clone())); - } - loop { match slf.st { DispatcherState::Processing => { @@ -350,7 +341,6 @@ where PollService::Continue => continue, }; - slf.shared.ready.set(false); slf.call_service(cx, item); } // handle write back-pressure @@ -480,16 +470,9 @@ where } fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.shared.ready.get() { - return Poll::Ready(self.check_error()); - } - // wait until service becomes ready match self.shared.service.poll_ready(cx) { - Poll::Ready(Ok(_)) => { - self.shared.ready.set(true); - Poll::Ready(self.check_error()) - } + Poll::Ready(Ok(_)) => Poll::Ready(self.check_error()), // pause io read task Poll::Pending => { log::trace!( @@ -628,30 +611,6 @@ where } } -async fn not_ready(slf: Rc>) -where - S: Service, Response = Option>> + 'static, - U: Encoder + Decoder + 'static, -{ - let pl = slf.service.clone(); - loop { - if !pl.is_shutdown() { - if let Err(err) = poll_fn(|cx| pl.poll_ready(cx)).await { - log::trace!("{}: Service readiness check failed, stopping", slf.io.tag()); - slf.error.set(Some(DispatcherError::Service(err))); - break; - } - if !pl.is_shutdown() { - poll_fn(|cx| pl.poll_not_ready(cx)).await; - slf.ready.set(false); - slf.io.wake(); - continue; - } - } - break; - } -} - #[cfg(test)] mod tests { use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex}; @@ -751,7 +710,6 @@ mod tests { io: state.into(), error: Cell::new(None), inflight: Cell::new(0), - ready: Cell::new(false), service: Pipeline::new(service).bind(), }); @@ -902,8 +860,6 @@ mod tests { Err("test") } - async fn not_ready(&self) {} - async fn call( &self, _: DispatchItem, @@ -1288,6 +1244,8 @@ mod tests { sleep(Millis(50)).await; if let DispatchItem::Item(msg) = msg { Ok::<_, ()>(Some(msg.freeze())) + } else if let DispatchItem::Disconnect(_) = msg { + Ok::<_, ()>(None) } else { panic!() } diff --git a/ntex-io/src/flags.rs b/ntex-io/src/flags.rs index bc9b5aac..1e65d2a7 100644 --- a/ntex-io/src/flags.rs +++ b/ntex-io/src/flags.rs @@ -25,6 +25,8 @@ bitflags::bitflags! { /// write task paused const WR_PAUSED = 0b0000_0100_0000_0000; + /// wait for write completion task + const WR_TASK_WAIT = 0b0000_1000_0000_0000; /// dispatcher is marked stopped const DSP_STOP = 0b0001_0000_0000_0000; @@ -38,6 +40,10 @@ impl Flags { self.intersects(Flags::IO_STOPPED) } + pub(crate) fn is_task_waiting_for_write(&self) -> bool { + self.contains(Flags::WR_TASK_WAIT) + } + pub(crate) fn is_waiting_for_write(&self) -> bool { self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE) } @@ -46,10 +52,18 @@ impl Flags { self.remove(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE); } + pub(crate) fn task_waiting_for_write_is_done(&mut self) { + self.remove(Flags::WR_TASK_WAIT); + } + pub(crate) fn is_read_buf_ready(&self) -> bool { self.contains(Flags::BUF_R_READY) } + pub(crate) fn is_waiting_for_read(&self) -> bool { + self.contains(Flags::RD_NOTIFY) + } + pub(crate) fn cannot_read(self) -> bool { self.intersects(Flags::RD_PAUSED | Flags::BUF_R_FULL) } diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index 19cd2d6f..498e249d 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -10,7 +10,7 @@ use ntex_util::{future::Either, task::LocalWaker, time::Seconds}; use crate::buf::Stack; use crate::filter::{Base, Filter, Layer, NullFilter}; use crate::flags::Flags; -use crate::seal::Sealed; +use crate::seal::{IoBoxed, Sealed}; use crate::tasks::{ReadContext, WriteContext}; use crate::timer::TimerHandle; use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError}; @@ -80,6 +80,23 @@ impl IoState { } } + /// Get current io error + pub(super) fn error(&self) -> Option { + if let Some(err) = self.error.take() { + self.error + .set(Some(io::Error::new(err.kind(), format!("{}", err)))); + Some(err) + } else { + None + } + } + + /// Get current io result + pub(super) fn error_or_disconnected(&self) -> io::Error { + self.error() + .unwrap_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "Disconnected")) + } + pub(super) fn io_stopped(&self, err: Option) { if err.is_some() { self.error.set(err); @@ -257,19 +274,6 @@ impl Io { fn io_ref(&self) -> &IoRef { unsafe { &*self.0.get() } } - - /// Get current io error - fn error(&self) -> Option { - self.st().error.take() - } - - /// Get current io error - fn error_or_disconnected(&self) -> io::Error { - self.st() - .error - .take() - .unwrap_or_else(|| io::Error::new(io::ErrorKind::Other, "Disconnected")) - } } impl Io> { @@ -290,6 +294,12 @@ impl Io { Io(UnsafeCell::new(state), marker::PhantomData) } + #[inline] + /// Convert current io stream into boxed version + pub fn boxed(self) -> IoBoxed { + self.seal().into() + } + #[inline] /// Map current filter with new one pub fn add_filter(self, nf: U) -> Io> @@ -333,7 +343,7 @@ impl Io { "Timeout", ))), Err(RecvError::Stop) => Err(Either::Right(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::UnexpectedEof, "Dispatcher stopped", ))), Err(RecvError::WriteBackpressure) => { @@ -423,11 +433,11 @@ impl Io { let mut flags = st.flags.get(); if flags.is_stopped() { - Poll::Ready(Err(self.error_or_disconnected())) + Poll::Ready(Err(st.error_or_disconnected())) } else { st.dispatch_task.register(cx.waker()); - let ready = flags.contains(Flags::BUF_R_READY); + let ready = flags.is_read_buf_ready(); if flags.cannot_read() { flags.cleanup_read_flags(); st.read_task.wake(); @@ -511,7 +521,7 @@ impl Io { let st = self.st(); let flags = st.flags.get(); if flags.is_stopped() { - Err(RecvError::PeerGone(self.error())) + Err(RecvError::PeerGone(st.error())) } else if flags.contains(Flags::DSP_STOP) { st.remove_flags(Flags::DSP_STOP); Err(RecvError::Stop) @@ -545,27 +555,31 @@ impl Io { /// otherwise wake up when size of write buffer is lower than /// buffer max size. pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll> { + let st = self.st(); let flags = self.flags(); - if flags.is_stopped() { - Poll::Ready(Err(self.error_or_disconnected())) - } else { - let st = self.st(); - let len = st.buffer.write_destination_size(); - if len > 0 { - if full { - st.insert_flags(Flags::BUF_W_MUST_FLUSH); - st.dispatch_task.register(cx.waker()); - return Poll::Pending; - } else if len >= st.pool.get().write_params_high() << 1 { - st.insert_flags(Flags::BUF_W_BACKPRESSURE); - st.dispatch_task.register(cx.waker()); - return Poll::Pending; - } + let len = st.buffer.write_destination_size(); + if len > 0 { + if full { + st.insert_flags(Flags::BUF_W_MUST_FLUSH); + st.dispatch_task.register(cx.waker()); + return if flags.is_stopped() { + Poll::Ready(Err(st.error_or_disconnected())) + } else { + Poll::Pending + }; + } else if len >= st.pool.get().write_params_high() << 1 { + st.insert_flags(Flags::BUF_W_BACKPRESSURE); + st.dispatch_task.register(cx.waker()); + return if flags.is_stopped() { + Poll::Ready(Err(st.error_or_disconnected())) + } else { + Poll::Pending + }; } - st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE); - Poll::Ready(Ok(())) } + st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE); + Poll::Ready(Ok(())) } #[inline] @@ -575,7 +589,7 @@ impl Io { let flags = st.flags.get(); if flags.is_stopped() { - if let Some(err) = self.error() { + if let Some(err) = st.error() { Poll::Ready(Err(err)) } else { Poll::Ready(Ok(())) @@ -611,7 +625,7 @@ impl Io { let st = self.st(); let flags = st.flags.get(); if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) { - Poll::Ready(IoStatusUpdate::PeerGone(self.error())) + Poll::Ready(IoStatusUpdate::PeerGone(st.error())) } else if flags.contains(Flags::DSP_STOP) { st.remove_flags(Flags::DSP_STOP); Poll::Ready(IoStatusUpdate::Stop) diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 054a1906..4e9f455b 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -191,7 +191,7 @@ impl IoRef { F: FnOnce(&mut BytesVec) -> R, { if self.0.flags.get().contains(Flags::IO_STOPPED) { - Err(io::Error::new(io::ErrorKind::Other, "Disconnected")) + Err(self.0.error_or_disconnected()) } else { let result = self.0.buffer.with_write_source(self, f); self.0.filter().process_write_buf(self, &self.0.buffer, 0)?; @@ -199,6 +199,16 @@ impl IoRef { } } + #[doc(hidden)] + #[inline] + /// Get mut access to destination write buffer + pub fn with_write_dest_buf(&self, f: F) -> R + where + F: FnOnce(Option<&mut BytesVec>) -> R, + { + self.0.buffer.with_write_destination(self, f) + } + #[inline] /// Get mut access to source read buffer pub fn with_read_buf(&self, f: F) -> R @@ -559,6 +569,10 @@ mod tests { assert_eq!(in_bytes.get(), BIN.len() * 2); assert_eq!(out_bytes.get(), 8); + assert_eq!( + state.with_write_dest_buf(|b| b.map(|b| b.len()).unwrap_or(0)), + 0 + ); // refs assert_eq!(Rc::strong_count(&in_bytes), 3); diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index cbfde011..6d4b6bdd 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -29,7 +29,7 @@ pub use self::filter::{Base, Filter, Layer}; pub use self::framed::Framed; pub use self::io::{Io, IoRef, OnDisconnect}; pub use self::seal::{IoBoxed, Sealed}; -pub use self::tasks::{ReadContext, WriteContext, WriteContextBuf}; +pub use self::tasks::{IoContext, ReadContext, WriteContext, WriteContextBuf}; pub use self::timer::TimerHandle; pub use self::utils::{seal, Decoded}; @@ -53,7 +53,9 @@ pub trait AsyncWrite { /// Status for read task #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum ReadStatus { + /// Read task is clear to proceed with read operation Ready, + /// Terminate read task Terminate, } diff --git a/ntex-io/src/seal.rs b/ntex-io/src/seal.rs index 0dbb8886..28dac673 100644 --- a/ntex-io/src/seal.rs +++ b/ntex-io/src/seal.rs @@ -1,6 +1,7 @@ -use std::{fmt, ops}; +use std::{any::Any, any::TypeId, fmt, io, ops, task::Context, task::Poll}; -use crate::{filter::Filter, Io}; +use crate::filter::{Filter, FilterReadStatus}; +use crate::{buf::Stack, Io, IoRef, ReadStatus, WriteStatus}; /// Sealed filter type pub struct Sealed(pub(crate) Box); @@ -11,6 +12,44 @@ impl fmt::Debug for Sealed { } } +impl Filter for Sealed { + #[inline] + fn query(&self, id: TypeId) -> Option> { + self.0.query(id) + } + + #[inline] + fn process_read_buf( + &self, + io: &IoRef, + stack: &Stack, + idx: usize, + nbytes: usize, + ) -> io::Result { + self.0.process_read_buf(io, stack, idx, nbytes) + } + + #[inline] + fn process_write_buf(&self, io: &IoRef, stack: &Stack, idx: usize) -> io::Result<()> { + self.0.process_write_buf(io, stack, idx) + } + + #[inline] + fn shutdown(&self, io: &IoRef, stack: &Stack, idx: usize) -> io::Result> { + self.0.shutdown(io, stack, idx) + } + + #[inline] + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { + self.0.poll_read_ready(cx) + } + + #[inline] + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { + self.0.poll_write_ready(cx) + } +} + #[derive(Debug)] /// Boxed `Io` object with erased filter type pub struct IoBoxed(Io); @@ -25,12 +64,6 @@ impl IoBoxed { } } -impl From> for IoBoxed { - fn from(io: Io) -> Self { - Self(io) - } -} - impl From> for IoBoxed { fn from(io: Io) -> Self { Self(io.seal()) @@ -45,3 +78,9 @@ impl ops::Deref for IoBoxed { &self.0 } } + +impl From for Io { + fn from(value: IoBoxed) -> Self { + value.0 + } +} diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 497e1f6c..55f99416 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -1,6 +1,6 @@ -use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll}; +use std::{cell::Cell, fmt, future::poll_fn, io, task::ready, task::Context, task::Poll}; -use ntex_bytes::{BufMut, BytesVec}; +use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep}; use crate::{AsyncRead, AsyncWrite, Flags, IoRef, ReadStatus, WriteStatus}; @@ -19,6 +19,13 @@ impl ReadContext { Self(io.clone(), Cell::new(None)) } + #[doc(hidden)] + #[inline] + /// Io tag + pub fn context(&self) -> IoContext { + IoContext::new(&self.0) + } + #[inline] /// Io tag pub fn tag(&self) -> &'static str { @@ -87,7 +94,7 @@ impl ReadContext { // handle incoming data let total2 = buf.len(); - let nbytes = if total2 > total { total2 - total } else { 0 }; + let nbytes = total2.saturating_sub(total); let total = total2; if let Some(mut first_buf) = inner.buffer.get_read_source() { @@ -121,7 +128,7 @@ impl ReadContext { ); // dest buffer has new data, wake up dispatcher inner.dispatch_task.wake(); - } else if inner.flags.get().contains(Flags::RD_NOTIFY) { + } else if inner.flags.get().is_waiting_for_read() { // in case of "notify" we must wake up dispatch task // if we read any data from source inner.dispatch_task.wake(); @@ -342,3 +349,604 @@ impl WriteContextBuf { } } } + +/// Context for io read task +pub struct IoContext(IoRef); + +impl fmt::Debug for IoContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("IoContext").field("io", &self.0).finish() + } +} + +impl IoContext { + pub(crate) fn new(io: &IoRef) -> Self { + Self(io.clone()) + } + + #[inline] + /// Io tag + pub fn tag(&self) -> &'static str { + self.0.tag() + } + + #[doc(hidden)] + /// Io flags + pub fn flags(&self) -> crate::flags::Flags { + self.0.flags() + } + + #[inline] + /// Check readiness for read operations + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { + self.shutdown_filters(); + self.0.filter().poll_read_ready(cx) + } + + #[inline] + /// Check readiness for write operations + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { + self.0.filter().poll_write_ready(cx) + } + + #[inline] + /// Get io error + pub fn stopped(&self, e: Option) { + self.0 .0.io_stopped(e); + } + + /// Wait when io get closed or preparing for close + pub async fn shutdown(&self, flush_buf: bool) { + let st = &self.0 .0; + let mut timeout = None; + + poll_fn(|cx| { + let flags = self.0.flags(); + + if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) { + Poll::Ready(()) + } else { + st.write_task.register(cx.waker()); + if flags.contains(Flags::IO_STOPPING_FILTERS) { + if timeout.is_none() { + timeout = Some(sleep(st.disconnect_timeout.get())); + } + if timeout.as_ref().unwrap().poll_elapsed(cx).is_ready() { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); + return Poll::Ready(()); + } + } + Poll::Pending + } + }) + .await; + + if flush_buf && !self.0.flags().contains(Flags::WR_PAUSED) { + st.insert_flags(Flags::WR_TASK_WAIT); + + poll_fn(|cx| { + let flags = self.0.flags(); + + if flags.intersects(Flags::WR_PAUSED | Flags::IO_STOPPED) { + Poll::Ready(()) + } else { + st.write_task.register(cx.waker()); + + if timeout.is_none() { + timeout = Some(sleep(st.disconnect_timeout.get())); + } + if timeout.as_ref().unwrap().poll_elapsed(cx).is_ready() { + Poll::Ready(()) + } else { + Poll::Pending + } + } + }) + .await; + } + } + + /// Get read buffer + pub fn get_read_buf(&self) -> Poll { + let inner = &self.0 .0; + + if let Some(waker) = inner.read_task.take() { + let mut cx = Context::from_waker(&waker); + + if let Poll::Ready(ReadStatus::Ready) = self.0.filter().poll_read_ready(&mut cx) + { + let mut buf = if inner.flags.get().is_read_buf_ready() { + // read buffer is still not read by dispatcher + // we cannot touch it + inner.pool.get().get_read_buf() + } else { + inner + .buffer + .get_read_source() + .unwrap_or_else(|| inner.pool.get().get_read_buf()) + }; + + // make sure we've got room + let (hw, lw) = self.0.memory_pool().read_params().unpack(); + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + return Poll::Ready(buf); + } + } + + Poll::Pending + } + + pub fn release_read_buf(&self, buf: BytesVec) { + let inner = &self.0 .0; + if let Some(mut first_buf) = inner.buffer.get_read_source() { + first_buf.extend_from_slice(&buf); + inner.buffer.set_read_source(&self.0, first_buf); + } else { + inner.buffer.set_read_source(&self.0, buf); + } + } + + /// Set read buffer + pub fn set_read_buf(&self, result: io::Result, buf: BytesVec) -> Poll<()> { + let inner = &self.0 .0; + let (hw, _) = self.0.memory_pool().read_params().unpack(); + + if let Some(mut first_buf) = inner.buffer.get_read_source() { + first_buf.extend_from_slice(&buf); + inner.buffer.set_read_source(&self.0, first_buf); + } else { + inner.buffer.set_read_source(&self.0, buf); + } + + match result { + Ok(0) => { + inner.io_stopped(None); + Poll::Ready(()) + } + Ok(nbytes) => { + let filter = self.0.filter(); + let res = filter + .process_read_buf(&self.0, &inner.buffer, 0, nbytes) + .and_then(|status| { + if status.nbytes > 0 { + // dest buffer has new data, wake up dispatcher + if inner.buffer.read_destination_size() >= hw { + log::trace!( + "{}: Io read buffer is too large {}, enable read back-pressure", + self.0.tag(), + nbytes + ); + inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); + } else { + inner.insert_flags(Flags::BUF_R_READY); + + if nbytes >= hw { + // read task is paused because of read back-pressure + // but there is no new data in top most read buffer + // so we need to wake up read task to read more data + // otherwise read task would sleep forever + inner.read_task.wake(); + } + } + log::trace!( + "{}: New {} bytes available, wakeup dispatcher", + self.0.tag(), + nbytes + ); + if !inner.dispatch_task.wake_checked() { + log::error!("Dispatcher waker is not registered"); + } + } else { + if nbytes >= hw { + // read task is paused because of read back-pressure + // but there is no new data in top most read buffer + // so we need to wake up read task to read more data + // otherwise read task would sleep forever + inner.read_task.wake(); + } + if inner.flags.get().is_waiting_for_read() { + // in case of "notify" we must wake up dispatch task + // if we read any data from source + inner.dispatch_task.wake(); + } + } + + // while reading, filter wrote some data + // in that case filters need to process write buffers + // and potentialy wake write task + if status.need_write { + inner.write_task.wake(); + filter.process_write_buf(&self.0, &inner.buffer, 0) + } else { + Ok(()) + } + }); + + if let Err(err) = res { + inner.io_stopped(Some(err)); + Poll::Ready(()) + } else { + self.shutdown_filters(); + Poll::Pending + } + } + Err(e) => { + inner.io_stopped(Some(e)); + Poll::Ready(()) + } + } + } + + /// Get write buffer + pub fn get_write_buf(&self) -> Poll { + let inner = &self.0 .0; + + // check write readiness + if let Some(waker) = inner.write_task.take() { + let ready = self + .0 + .filter() + .poll_write_ready(&mut Context::from_waker(&waker)); + let buf = if matches!( + ready, + Poll::Ready(WriteStatus::Ready | WriteStatus::Shutdown) + ) { + inner.buffer.get_write_destination().and_then(|buf| { + if buf.is_empty() { + None + } else { + Some(buf) + } + }) + } else { + None + }; + + if let Some(buf) = buf { + return Poll::Ready(buf); + } + } + Poll::Pending + } + + pub fn release_write_buf(&self, mut buf: BytesVec) { + let inner = &self.0 .0; + + if let Some(b) = inner.buffer.get_write_destination() { + buf.extend_from_slice(&b); + self.0.memory_pool().release_write_buf(b); + } + inner.buffer.set_write_destination(buf); + + // if write buffer is smaller than high watermark value, turn off back-pressure + let len = inner.buffer.write_destination_size(); + let mut flags = inner.flags.get(); + + if len == 0 { + if flags.is_waiting_for_write() { + flags.waiting_for_write_is_done(); + inner.dispatch_task.wake(); + } + flags.insert(Flags::WR_PAUSED); + inner.flags.set(flags); + } else if flags.contains(Flags::BUF_W_BACKPRESSURE) + && len < inner.pool.get().write_params_high() << 1 + { + flags.remove(Flags::BUF_W_BACKPRESSURE); + inner.flags.set(flags); + inner.dispatch_task.wake(); + } + inner.flags.set(flags); + } + + /// Set write buffer + pub fn set_write_buf(&self, result: io::Result, mut buf: BytesVec) -> Poll<()> { + let result = match result { + Ok(0) => { + log::trace!("{}: Disconnected during flush", self.tag()); + Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + )) + } + Ok(n) => { + if n == buf.len() { + buf.clear(); + Ok(0) + } else { + buf.advance(n); + Ok(buf.len()) + } + } + Err(e) => Err(e), + }; + + let inner = &self.0 .0; + + // set buffer back + let result = match result { + Ok(0) => { + // log::debug!("{}: WROTE ALL {:?}", self.0.tag(), inner.buffer.write_destination_size()); + self.0.memory_pool().release_write_buf(buf); + Ok(inner.buffer.write_destination_size()) + } + Ok(_) => { + if let Some(b) = inner.buffer.get_write_destination() { + buf.extend_from_slice(&b); + self.0.memory_pool().release_write_buf(b); + } + let l = buf.len(); + // log::debug!("{}: WROTE SOME {:?}", self.0.tag(), l); + inner.buffer.set_write_destination(buf); + Ok(l) + } + Err(e) => Err(e), + }; + + let mut flags = inner.flags.get(); + match result { + Ok(0) => { + // all data has been written + flags.insert(Flags::WR_PAUSED); + + if flags.is_task_waiting_for_write() { + flags.task_waiting_for_write_is_done(); + inner.write_task.wake(); + } + + if flags.is_waiting_for_write() { + flags.waiting_for_write_is_done(); + inner.dispatch_task.wake(); + } + inner.flags.set(flags); + Poll::Ready(()) + } + Ok(len) => { + // if write buffer is smaller than high watermark value, turn off back-pressure + if flags.contains(Flags::BUF_W_BACKPRESSURE) + && len < inner.pool.get().write_params_high() << 1 + { + flags.remove(Flags::BUF_W_BACKPRESSURE); + inner.flags.set(flags); + inner.dispatch_task.wake(); + } + Poll::Pending + } + Err(e) => { + inner.io_stopped(Some(e)); + Poll::Ready(()) + } + } + } + + /// Get read buffer + pub fn is_read_ready(&self) -> bool { + // check read readiness + if let Some(waker) = self.0 .0.read_task.take() { + let mut cx = Context::from_waker(&waker); + + if let Poll::Ready(ReadStatus::Ready) = self.0.filter().poll_read_ready(&mut cx) + { + return true; + } + } + false + } + + pub fn with_read_buf(&self, f: F) -> Poll<()> + where + F: FnOnce(&mut BytesVec) -> Poll>, + { + let inner = &self.0 .0; + let (hw, lw) = self.0.memory_pool().read_params().unpack(); + let result = inner.buffer.with_read_source(&self.0, |buf| { + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + + f(buf) + }); + + // handle buffer changes + match result { + Poll::Ready(Ok(0)) => { + inner.io_stopped(None); + Poll::Ready(()) + } + Poll::Ready(Ok(nbytes)) => { + let filter = self.0.filter(); + let _ = filter + .process_read_buf(&self.0, &inner.buffer, 0, nbytes) + .and_then(|status| { + if status.nbytes > 0 { + // dest buffer has new data, wake up dispatcher + if inner.buffer.read_destination_size() >= hw { + log::trace!( + "{}: Io read buffer is too large {}, enable read back-pressure", + self.0.tag(), + nbytes + ); + inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); + } else { + inner.insert_flags(Flags::BUF_R_READY); + + if nbytes >= hw { + // read task is paused because of read back-pressure + // but there is no new data in top most read buffer + // so we need to wake up read task to read more data + // otherwise read task would sleep forever + inner.read_task.wake(); + } + } + log::trace!( + "{}: New {} bytes available, wakeup dispatcher", + self.0.tag(), + nbytes + ); + if !inner.dispatch_task.wake_checked() { + log::error!("Dispatcher waker is not registered"); + } + } else { + if nbytes >= hw { + // read task is paused because of read back-pressure + // but there is no new data in top most read buffer + // so we need to wake up read task to read more data + // otherwise read task would sleep forever + inner.read_task.wake(); + } + if inner.flags.get().is_waiting_for_read() { + // in case of "notify" we must wake up dispatch task + // if we read any data from source + inner.dispatch_task.wake(); + } + } + + // while reading, filter wrote some data + // in that case filters need to process write buffers + // and potentialy wake write task + if status.need_write { + filter.process_write_buf(&self.0, &inner.buffer, 0) + } else { + Ok(()) + } + }) + .map_err(|err| { + inner.dispatch_task.wake(); + inner.io_stopped(Some(err)); + inner.insert_flags(Flags::BUF_R_READY); + }); + Poll::Pending + } + Poll::Ready(Err(e)) => { + inner.io_stopped(Some(e)); + Poll::Ready(()) + } + Poll::Pending => { + self.shutdown_filters(); + Poll::Pending + } + } + } + + /// Get write buffer + pub fn with_write_buf(&self, f: F) -> Poll<()> + where + F: FnOnce(&BytesVec) -> Poll>, + { + let inner = &self.0 .0; + let result = inner.buffer.with_write_destination(&self.0, |buf| { + let Some(buf) = + buf.and_then(|buf| if buf.is_empty() { None } else { Some(buf) }) + else { + return Poll::Ready(Ok(0)); + }; + + match ready!(f(buf)) { + Ok(0) => { + log::trace!("{}: Disconnected during flush", self.tag()); + Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))) + } + Ok(n) => { + if n == buf.len() { + buf.clear(); + Poll::Ready(Ok(0)) + } else { + buf.advance(n); + Poll::Ready(Ok(buf.len())) + } + } + Err(e) => Poll::Ready(Err(e)), + } + }); + + let mut flags = inner.flags.get(); + + let result = match result { + Poll::Pending => { + flags.remove(Flags::WR_PAUSED); + Poll::Pending + } + Poll::Ready(Ok(0)) => { + // all data has been written + flags.insert(Flags::WR_PAUSED); + + if flags.is_task_waiting_for_write() { + flags.task_waiting_for_write_is_done(); + inner.write_task.wake(); + } + + if flags.is_waiting_for_write() { + flags.waiting_for_write_is_done(); + inner.dispatch_task.wake(); + } + Poll::Ready(()) + } + Poll::Ready(Ok(len)) => { + // if write buffer is smaller than high watermark value, turn off back-pressure + if flags.contains(Flags::BUF_W_BACKPRESSURE) + && len < inner.pool.get().write_params_high() << 1 + { + flags.remove(Flags::BUF_W_BACKPRESSURE); + inner.dispatch_task.wake(); + } + Poll::Pending + } + Poll::Ready(Err(e)) => { + self.0 .0.io_stopped(Some(e)); + Poll::Ready(()) + } + }; + + inner.flags.set(flags); + result + } + + fn shutdown_filters(&self) { + let io = &self.0; + let st = &self.0 .0; + if st.flags.get().contains(Flags::IO_STOPPING_FILTERS) { + let flags = st.flags.get(); + + if !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) { + let filter = io.filter(); + match filter.shutdown(io, &st.buffer, 0) { + Ok(Poll::Ready(())) => { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); + } + Ok(Poll::Pending) => { + // check read buffer, if buffer is not consumed it is unlikely + // that filter will properly complete shutdown + if flags.contains(Flags::RD_PAUSED) + || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY) + { + st.dispatch_task.wake(); + st.insert_flags(Flags::IO_STOPPING); + } + } + Err(err) => { + st.io_stopped(Some(err)); + } + } + if let Err(err) = filter.process_write_buf(io, &st.buffer, 0) { + st.io_stopped(Some(err)); + } + } + } + } +} + +impl Clone for IoContext { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} diff --git a/ntex-io/src/utils.rs b/ntex-io/src/utils.rs index c2cd2bbb..d5f0d563 100644 --- a/ntex-io/src/utils.rs +++ b/ntex-io/src/utils.rs @@ -27,7 +27,7 @@ where S: ServiceFactory, C: Clone, { - chain_factory(fn_service(|io: Io| Ready::Ok(IoBoxed::from(io)))) + chain_factory(fn_service(|io: Io| Ready::Ok(io.boxed()))) .map_init_err(|_| panic!()) .and_then(srv) } diff --git a/ntex-macros/CHANGES.md b/ntex-macros/CHANGES.md index 8f29a797..8d884c08 100644 --- a/ntex-macros/CHANGES.md +++ b/ntex-macros/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.1.4] - 2025-03-14 + +* Enable env_logger for test macro + ## [0.1.2] - 2021-02-25 * Export runtime from ntex crate diff --git a/ntex-macros/Cargo.toml b/ntex-macros/Cargo.toml index 512b4501..a5bcf67d 100644 --- a/ntex-macros/Cargo.toml +++ b/ntex-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-macros" -version = "0.1.3" +version = "0.1.4" description = "ntex proc macros" readme = "README.md" authors = ["ntex contributors "] @@ -18,4 +18,3 @@ proc-macro2 = "^1" [dev-dependencies] ntex = "2" futures = "0.3" -env_logger = "0.11" diff --git a/ntex-macros/src/lib.rs b/ntex-macros/src/lib.rs index 8c6c29d2..3179296a 100644 --- a/ntex-macros/src/lib.rs +++ b/ntex-macros/src/lib.rs @@ -262,6 +262,7 @@ pub fn rt_test(_: TokenStream, item: TokenStream) -> TokenStream { quote! { #(#attrs)* fn #name() #ret { + ntex::util::enable_test_logging(); ntex::rt::System::new("test") .block_on(async { #body }) } @@ -271,6 +272,7 @@ pub fn rt_test(_: TokenStream, item: TokenStream) -> TokenStream { #[test] #(#attrs)* fn #name() #ret { + ntex::util::enable_test_logging(); ntex::rt::System::new("test") .block_on(async { #body }) } diff --git a/ntex-net/CHANGES.md b/ntex-net/CHANGES.md index 682d9160..e60744ef 100644 --- a/ntex-net/CHANGES.md +++ b/ntex-net/CHANGES.md @@ -1,5 +1,55 @@ # Changes +## [2.5.10] - 2025-03-28 + +* Better closed sockets handling + +## [2.5.9] - 2025-03-27 + +* Handle closed sockets + +## [2.5.8] - 2025-03-25 + +* Update neon runtime + +## [2.5.7] - 2025-03-21 + +* Simplify neon poll impl + +## [2.5.6] - 2025-03-20 + +* Redesign neon poll support + +## [2.5.5] - 2025-03-17 + +* Add check for required io-uring opcodes + +* Handle io-uring cancelation + +## [2.5.4] - 2025-03-15 + +* Close FD in various case for poll driver + +## [2.5.3] - 2025-03-14 + +* Fix operation cancelation handling for poll driver + +## [2.5.2] - 2025-03-14 + +* Fix operation cancelation handling for io-uring driver + +## [2.5.1] - 2025-03-14 + +* Fix socket connect for io-uring driver + +## [2.5.0] - 2025-03-12 + +* Add neon runtime support + +* Drop glommio support + +* Drop async-std support + ## [2.4.0] - 2024-09-25 * Update to glommio v0.9 diff --git a/ntex-net/Cargo.toml b/ntex-net/Cargo.toml index a26f2071..5a72d3eb 100644 --- a/ntex-net/Cargo.toml +++ b/ntex-net/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-net" -version = "2.4.0" +version = "2.5.10" authors = ["ntex contributors "] description = "ntexwork utils for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -24,28 +24,36 @@ tokio = ["ntex-rt/tokio", "ntex-tokio"] # compio runtime compio = ["ntex-rt/compio", "ntex-compio"] -# glommio runtime -glommio = ["ntex-rt/glommio", "ntex-glommio"] +# neon runtime +neon = ["ntex-rt/neon", "ntex-neon", "slab", "socket2"] -# async-std runtime -async-std = ["ntex-rt/async-std", "ntex-async-std"] +polling = ["ntex-neon/polling", "dep:polling", "socket2"] +io-uring = ["ntex-neon/io-uring", "dep:io-uring", "socket2"] [dependencies] ntex-service = "3.3" ntex-bytes = "0.1" ntex-http = "0.1" -ntex-io = "2.8" -ntex-rt = "0.4.18" +ntex-io = "2.11.1" +ntex-rt = "0.4.25" ntex-util = "2.5" ntex-tokio = { version = "0.5.3", optional = true } -ntex-compio = { version = "0.2.1", optional = true } -ntex-glommio = { version = "0.5.2", optional = true } -ntex-async-std = { version = "0.5.1", optional = true } +ntex-compio = { version = "0.2.4", optional = true } +ntex-neon = { version = "0.1.15", optional = true } -log = "0.4" -thiserror = "1" +bitflags = { workspace = true } +cfg-if = { workspace = true } +log = { workspace = true } +libc = { workspace = true } +slab = { workspace = true, optional = true } +socket2 = { workspace = true, optional = true, features = ["all"] } +thiserror = { workspace = true } + +# Linux specific dependencies +[target.'cfg(target_os = "linux")'.dependencies] +io-uring = { workspace = true, optional = true } +polling = { workspace = true, optional = true } [dev-dependencies] ntex = "2" -env_logger = "0.11" diff --git a/ntex-net/src/compat.rs b/ntex-net/src/compat.rs index ad320882..fdc84f71 100644 --- a/ntex-net/src/compat.rs +++ b/ntex-net/src/compat.rs @@ -6,63 +6,18 @@ pub use ntex_tokio::{from_tcp_stream, tcp_connect, tcp_connect_in}; #[cfg(all(unix, feature = "tokio"))] pub use ntex_tokio::{from_unix_stream, unix_connect, unix_connect_in}; -#[cfg(all( - feature = "compio", - not(feature = "tokio"), - not(feature = "async-std"), - not(feature = "glommio") -))] +#[cfg(all(feature = "compio", not(feature = "tokio"), not(feature = "neon")))] pub use ntex_compio::{from_tcp_stream, tcp_connect, tcp_connect_in}; #[cfg(all( unix, feature = "compio", not(feature = "tokio"), - not(feature = "async-std"), - not(feature = "glommio") + not(feature = "neon") ))] pub use ntex_compio::{from_unix_stream, unix_connect, unix_connect_in}; -#[cfg(all( - feature = "async-std", - not(feature = "tokio"), - not(feature = "compio"), - not(feature = "glommio") -))] -pub use ntex_async_std::{from_tcp_stream, tcp_connect, tcp_connect_in}; - -#[cfg(all( - unix, - feature = "async-std", - not(feature = "tokio"), - not(feature = "compio"), - not(feature = "glommio") -))] -pub use ntex_async_std::{from_unix_stream, unix_connect, unix_connect_in}; - -#[cfg(all( - feature = "glommio", - not(feature = "tokio"), - not(feature = "compio"), - not(feature = "async-std") -))] -pub use ntex_glommio::{from_tcp_stream, tcp_connect, tcp_connect_in}; - -#[cfg(all( - unix, - feature = "glommio", - not(feature = "tokio"), - not(feature = "compio"), - not(feature = "async-std") -))] -pub use ntex_glommio::{from_unix_stream, unix_connect, unix_connect_in}; - -#[cfg(all( - not(feature = "tokio"), - not(feature = "compio"), - not(feature = "async-std"), - not(feature = "glommio") -))] +#[cfg(all(not(feature = "tokio"), not(feature = "compio"), not(feature = "neon")))] mod no_rt { use ntex_io::Io; @@ -127,10 +82,5 @@ mod no_rt { } } -#[cfg(all( - not(feature = "tokio"), - not(feature = "compio"), - not(feature = "async-std"), - not(feature = "glommio") -))] +#[cfg(all(not(feature = "tokio"), not(feature = "compio"), not(feature = "neon")))] pub use no_rt::*; diff --git a/ntex-net/src/connect/service.rs b/ntex-net/src/connect/service.rs index 4969b9fd..9e6a0549 100644 --- a/ntex-net/src/connect/service.rs +++ b/ntex-net/src/connect/service.rs @@ -197,7 +197,7 @@ impl Future for TcpConnectorResponse { Poll::Ready(Ok(sock)) => { let req = this.req.take().unwrap(); log::trace!( - "{}: TCP connector - successfully connected to connecting to {:?} - {:?}", + "{}: TCP connector - successfully connected to {:?} - {:?}", this.tag, req.host(), sock.query::().get() diff --git a/ntex-net/src/helpers.rs b/ntex-net/src/helpers.rs new file mode 100644 index 00000000..588acf65 --- /dev/null +++ b/ntex-net/src/helpers.rs @@ -0,0 +1,86 @@ +use std::{io, net::SocketAddr, os::fd::FromRawFd, path::Path}; + +use ntex_neon::syscall; +use ntex_util::channel::oneshot::channel; +use socket2::{Protocol, SockAddr, Socket, Type}; + +pub(crate) fn pool_io_err(result: std::result::Result) -> io::Result { + result.map_err(|_| io::Error::new(io::ErrorKind::Other, "Thread pool panic")) +} + +pub(crate) async fn connect(addr: SocketAddr) -> io::Result { + let addr = SockAddr::from(addr); + let domain = addr.domain().into(); + connect_inner(addr, domain, Type::STREAM.into(), Protocol::TCP.into()).await +} + +pub(crate) async fn connect_unix(path: impl AsRef) -> io::Result { + let addr = SockAddr::unix(path)?; + connect_inner(addr, socket2::Domain::UNIX.into(), Type::STREAM.into(), 0).await +} + +async fn connect_inner( + addr: SockAddr, + domain: i32, + socket_type: i32, + protocol: i32, +) -> io::Result { + #[allow(unused_mut)] + let mut ty = socket_type; + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "hurd", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + ))] + { + ty |= libc::SOCK_CLOEXEC; + } + + let fd = ntex_rt::spawn_blocking(move || syscall!(libc::socket(domain, ty, protocol))) + .await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + .and_then(pool_io_err)?; + + let (sender, rx) = channel(); + + crate::rt_impl::connect::ConnectOps::current().connect(fd, addr, sender)?; + + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "IO Driver is gone")) + .and_then(|item| item)?; + + Ok(unsafe { Socket::from_raw_fd(fd) }) +} + +pub(crate) fn prep_socket(sock: Socket) -> io::Result { + #[cfg(not(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "hurd", + target_os = "illumos", + target_os = "linux", + target_os = "netbsd", + target_os = "openbsd", + target_os = "espidf", + target_os = "vita", + )))] + sock.set_cloexec(true)?; + #[cfg(any( + target_os = "ios", + target_os = "macos", + target_os = "tvos", + target_os = "watchos", + ))] + sock.set_nosigpipe(true)?; + sock.set_nonblocking(true)?; + + Ok(sock) +} diff --git a/ntex-net/src/lib.rs b/ntex-net/src/lib.rs index 60a57add..ddc272bb 100644 --- a/ntex-net/src/lib.rs +++ b/ntex-net/src/lib.rs @@ -1,5 +1,6 @@ //! Utility for async runtime abstraction #![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)] +#![allow(unused_variables, dead_code)] mod compat; pub mod connect; @@ -7,4 +8,25 @@ pub mod connect; pub use ntex_io::Io; pub use ntex_rt::{spawn, spawn_blocking}; -pub use self::compat::*; +cfg_if::cfg_if! { + if #[cfg(all(feature = "neon", target_os = "linux", feature = "io-uring"))] { + #[path = "rt_uring/mod.rs"] + mod rt_impl; + pub use self::rt_impl::{ + from_tcp_stream, from_unix_stream, tcp_connect, tcp_connect_in, unix_connect, + unix_connect_in, + }; + } else if #[cfg(all(unix, feature = "neon"))] { + #[path = "rt_polling/mod.rs"] + mod rt_impl; + pub use self::rt_impl::{ + from_tcp_stream, from_unix_stream, tcp_connect, tcp_connect_in, unix_connect, + unix_connect_in, + }; + } else { + pub use self::compat::*; + } +} + +#[cfg(all(unix, feature = "neon"))] +mod helpers; diff --git a/ntex-net/src/rt_polling/connect.rs b/ntex-net/src/rt_polling/connect.rs new file mode 100644 index 00000000..8f0f1dc9 --- /dev/null +++ b/ntex-net/src/rt_polling/connect.rs @@ -0,0 +1,111 @@ +use std::os::fd::{AsRawFd, RawFd}; +use std::{cell::RefCell, io, rc::Rc, task::Poll}; + +use ntex_neon::driver::{DriverApi, Event, Handler}; +use ntex_neon::{syscall, Runtime}; +use ntex_util::channel::oneshot::Sender; +use slab::Slab; +use socket2::SockAddr; + +#[derive(Clone)] +pub(crate) struct ConnectOps(Rc); + +#[derive(Debug)] +enum Change { + Event(Event), + Error(io::Error), +} + +struct ConnectOpsBatcher { + inner: Rc, +} + +struct Item { + fd: RawFd, + sender: Sender>, +} + +struct ConnectOpsInner { + api: DriverApi, + connects: RefCell>, +} + +impl ConnectOps { + pub(crate) fn current() -> Self { + Runtime::value(|rt| { + let mut inner = None; + rt.driver().register(|api| { + let ops = Rc::new(ConnectOpsInner { + api, + connects: RefCell::new(Slab::new()), + }); + inner = Some(ops.clone()); + Box::new(ConnectOpsBatcher { inner: ops }) + }); + + ConnectOps(inner.unwrap()) + }) + } + + pub(crate) fn connect( + &self, + fd: RawFd, + addr: SockAddr, + sender: Sender>, + ) -> io::Result { + let result = syscall!(break libc::connect(fd, addr.as_ptr(), addr.len())); + + if let Poll::Ready(res) = result { + res?; + } + + let item = Item { fd, sender }; + let id = self.0.connects.borrow_mut().insert(item); + + self.0.api.attach(fd, id as u32, Some(Event::writable(0))); + Ok(id) + } +} + +impl Handler for ConnectOpsBatcher { + fn event(&mut self, id: usize, event: Event) { + log::debug!("connect-fd is readable {:?}", id); + + let mut connects = self.inner.connects.borrow_mut(); + + if connects.contains(id) { + let item = connects.remove(id); + if event.writable { + let mut err: libc::c_int = 0; + let mut err_len = std::mem::size_of::() as libc::socklen_t; + + let res = syscall!(libc::getsockopt( + item.fd.as_raw_fd(), + libc::SOL_SOCKET, + libc::SO_ERROR, + &mut err as *mut _ as *mut _, + &mut err_len + )); + + let res = if err == 0 { + res.map(|_| ()) + } else { + Err(io::Error::from_raw_os_error(err)) + }; + + self.inner.api.detach(item.fd, id as u32); + let _ = item.sender.send(res); + } + } + } + + fn error(&mut self, id: usize, err: io::Error) { + let mut connects = self.inner.connects.borrow_mut(); + + if connects.contains(id) { + let item = connects.remove(id); + let _ = item.sender.send(Err(err)); + self.inner.api.detach(item.fd, id as u32); + } + } +} diff --git a/ntex-net/src/rt_polling/driver.rs b/ntex-net/src/rt_polling/driver.rs new file mode 100644 index 00000000..24db553d --- /dev/null +++ b/ntex-net/src/rt_polling/driver.rs @@ -0,0 +1,368 @@ +use std::os::fd::{AsRawFd, RawFd}; +use std::{cell::Cell, cell::RefCell, future::Future, io, mem, rc::Rc, task, task::Poll}; + +use ntex_neon::driver::{DriverApi, Event, Handler}; +use ntex_neon::{syscall, Runtime}; +use slab::Slab; + +use ntex_bytes::BufMut; +use ntex_io::IoContext; + +pub(crate) struct StreamCtl { + id: u32, + inner: Rc>, +} + +bitflags::bitflags! { + #[derive(Copy, Clone, Debug)] + struct Flags: u8 { + const RD = 0b0000_0001; + const WR = 0b0000_0010; + } +} + +struct StreamItem { + io: Option, + fd: RawFd, + flags: Flags, + ref_count: u16, + context: IoContext, +} + +pub(crate) struct StreamOps(Rc>); + +struct StreamOpsHandler { + inner: Rc>, +} + +struct StreamOpsInner { + api: DriverApi, + delayd_drop: Cell, + feed: RefCell>, + streams: Cell>>>>, +} + +impl StreamItem { + fn tag(&self) -> &'static str { + self.context.tag() + } +} + +impl StreamOps { + pub(crate) fn current() -> Self { + Runtime::value(|rt| { + let mut inner = None; + rt.driver().register(|api| { + let ops = Rc::new(StreamOpsInner { + api, + feed: RefCell::new(Vec::new()), + delayd_drop: Cell::new(false), + streams: Cell::new(Some(Box::new(Slab::new()))), + }); + inner = Some(ops.clone()); + Box::new(StreamOpsHandler { inner: ops }) + }); + + StreamOps(inner.unwrap()) + }) + } + + pub(crate) fn register(&self, io: T, context: IoContext) -> StreamCtl { + let fd = io.as_raw_fd(); + let stream = self.0.with(move |streams| { + let item = StreamItem { + fd, + context, + io: Some(io), + ref_count: 1, + flags: Flags::empty(), + }; + StreamCtl { + id: streams.insert(item) as u32, + inner: self.0.clone(), + } + }); + + self.0.api.attach( + fd, + stream.id, + Some(Event::new(0, false, false).with_interrupt()), + ); + stream + } +} + +impl Clone for StreamOps { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Handler for StreamOpsHandler { + fn event(&mut self, id: usize, ev: Event) { + self.inner.with(|streams| { + if !streams.contains(id) { + return; + } + let item = &mut streams[id]; + if item.io.is_none() { + return; + } + log::debug!("{}: FD event {:?} event: {:?}", item.tag(), id, ev); + + // handle HUP + if ev.is_interrupt() { + item.context.stopped(None); + close(id as u32, item, &self.inner.api, None, true); + return; + } + + let mut renew_ev = Event::new(0, false, false).with_interrupt(); + + if ev.readable { + let res = item.context.with_read_buf(|buf| { + let chunk = buf.chunk_mut(); + let result = task::ready!(syscall!( + break libc::read(item.fd, chunk.as_mut_ptr() as _, chunk.len()) + )); + if let Ok(size) = result { + log::debug!("{}: data {:?}, s: {:?}", item.tag(), item.fd, size); + unsafe { buf.advance_mut(size) }; + } + Poll::Ready(result) + }); + + if res.is_pending() && item.context.is_read_ready() { + renew_ev.readable = true; + item.flags.insert(Flags::RD); + } else { + item.flags.remove(Flags::RD); + } + } else if item.flags.contains(Flags::RD) { + renew_ev.readable = true; + } + + if ev.writable { + let result = item.context.with_write_buf(|buf| { + log::debug!("{}: write {:?} s: {:?}", item.tag(), item.fd, buf.len()); + syscall!(break libc::write(item.fd, buf[..].as_ptr() as _, buf.len())) + }); + if result.is_pending() { + renew_ev.writable = true; + item.flags.insert(Flags::WR); + } else { + item.flags.remove(Flags::WR); + } + } else if item.flags.contains(Flags::WR) { + renew_ev.writable = true; + } + + self.inner.api.modify(item.fd, id as u32, renew_ev); + + // delayed drops + if self.inner.delayd_drop.get() { + for id in self.inner.feed.borrow_mut().drain(..) { + let item = &mut streams[id as usize]; + item.ref_count -= 1; + if item.ref_count == 0 { + let mut item = streams.remove(id as usize); + log::debug!( + "{}: Drop ({}), {:?}, has-io: {}", + item.tag(), + id, + item.fd, + item.io.is_some() + ); + close(id, &mut item, &self.inner.api, None, true); + } + } + self.inner.delayd_drop.set(false); + } + }); + } + + fn error(&mut self, id: usize, err: io::Error) { + self.inner.with(|streams| { + if let Some(item) = streams.get_mut(id) { + log::debug!( + "{}: FD is failed ({}) {:?}, err: {:?}", + item.tag(), + id, + item.fd, + err + ); + close(id as u32, item, &self.inner.api, Some(err), false); + } + }) + } +} + +impl StreamOpsInner { + fn with(&self, f: F) -> R + where + F: FnOnce(&mut Slab>) -> R, + { + let mut streams = self.streams.take().unwrap(); + let result = f(&mut streams); + self.streams.set(Some(streams)); + result + } +} + +fn close( + id: u32, + item: &mut StreamItem, + api: &DriverApi, + error: Option, + shutdown: bool, +) -> Option>> { + if let Some(io) = item.io.take() { + log::debug!("{}: Closing ({}), {:?}", item.tag(), id, item.fd); + mem::forget(io); + if let Some(err) = error { + item.context.stopped(Some(err)); + } + let fd = item.fd; + api.detach(fd, id); + Some(ntex_rt::spawn_blocking(move || { + if shutdown { + let _ = syscall!(libc::shutdown(fd, libc::SHUT_RDWR)); + } + syscall!(libc::close(fd)) + })) + } else { + None + } +} + +impl StreamCtl { + pub(crate) fn close(self) -> impl Future> { + let id = self.id as usize; + let fut = self.inner.with(|streams| { + let item = &mut streams[id]; + close(self.id, item, &self.inner.api, None, false) + }); + async move { + if let Some(fut) = fut { + fut.await + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + .and_then(crate::helpers::pool_io_err)?; + } + Ok(()) + } + } + + pub(crate) fn with_io(&self, f: F) -> R + where + F: FnOnce(Option<&T>) -> R, + { + self.inner + .with(|streams| f(streams[self.id as usize].io.as_ref())) + } + + pub(crate) fn modify(&self, rd: bool, wr: bool) { + self.inner.with(|streams| { + let item = &mut streams[self.id as usize]; + + log::debug!( + "{}: Modify interest ({}), {:?} rd: {:?}, wr: {:?}", + item.tag(), + self.id, + item.fd, + rd, + wr + ); + + let mut event = Event::new(0, false, false).with_interrupt(); + + if rd { + if item.flags.contains(Flags::RD) { + event.readable = true; + } else { + let res = item.context.with_read_buf(|buf| { + let chunk = buf.chunk_mut(); + let result = task::ready!(syscall!( + break libc::read(item.fd, chunk.as_mut_ptr() as _, chunk.len()) + )); + if let Ok(size) = result { + log::debug!( + "{}: read {:?}, s: {:?}", + item.tag(), + item.fd, + size + ); + unsafe { buf.advance_mut(size) }; + } + Poll::Ready(result) + }); + + if res.is_pending() && item.context.is_read_ready() { + event.readable = true; + item.flags.insert(Flags::RD); + } + } + } + + if wr { + if item.flags.contains(Flags::WR) { + event.writable = true; + } else { + let result = item.context.with_write_buf(|buf| { + log::debug!( + "{}: Writing ({}), buf: {:?}", + item.tag(), + self.id, + buf.len() + ); + syscall!( + break libc::write(item.fd, buf[..].as_ptr() as _, buf.len()) + ) + }); + + if result.is_pending() { + event.writable = true; + item.flags.insert(Flags::WR); + } + } + } + + self.inner.api.modify(item.fd, self.id, event); + }) + } +} + +impl Clone for StreamCtl { + fn clone(&self) -> Self { + self.inner.with(|streams| { + streams[self.id as usize].ref_count += 1; + Self { + id: self.id, + inner: self.inner.clone(), + } + }) + } +} + +impl Drop for StreamCtl { + fn drop(&mut self) { + if let Some(mut streams) = self.inner.streams.take() { + let id = self.id as usize; + streams[id].ref_count -= 1; + if streams[id].ref_count == 0 { + let mut item = streams.remove(id); + log::debug!( + "{}: Drop io ({}), {:?}, has-io: {}", + item.tag(), + self.id, + item.fd, + item.io.is_some() + ); + close(self.id, &mut item, &self.inner.api, None, true); + } + self.inner.streams.set(Some(streams)); + } else { + self.inner.delayd_drop.set(true); + self.inner.feed.borrow_mut().push(self.id); + } + } +} diff --git a/ntex-net/src/rt_polling/io.rs b/ntex-net/src/rt_polling/io.rs new file mode 100644 index 00000000..990dae8f --- /dev/null +++ b/ntex-net/src/rt_polling/io.rs @@ -0,0 +1,101 @@ +use std::{any, future::poll_fn, task::Poll}; + +use ntex_io::{ + types, Handle, IoContext, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus, +}; +use ntex_rt::spawn; +use socket2::Socket; + +use super::driver::{StreamCtl, StreamOps}; + +impl IoStream for super::TcpStream { + fn start(self, read: ReadContext, _: WriteContext) -> Option> { + let io = self.0; + let context = read.context(); + let ctl = StreamOps::current().register(io, context.clone()); + let ctl2 = ctl.clone(); + spawn(async move { run(ctl, context).await }); + + Some(Box::new(HandleWrapper(ctl2))) + } +} + +impl IoStream for super::UnixStream { + fn start(self, read: ReadContext, _: WriteContext) -> Option> { + let io = self.0; + let context = read.context(); + let ctl = StreamOps::current().register(io, context.clone()); + spawn(async move { run(ctl, context).await }); + + None + } +} + +struct HandleWrapper(StreamCtl); + +impl Handle for HandleWrapper { + fn query(&self, id: any::TypeId) -> Option> { + if id == any::TypeId::of::() { + let addr = self.0.with_io(|io| io.and_then(|io| io.peer_addr().ok())); + if let Some(addr) = addr.and_then(|addr| addr.as_socket()) { + return Some(Box::new(types::PeerAddr(addr))); + } + } + None + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum Status { + Shutdown, + Terminate, +} + +async fn run(ctl: StreamCtl, context: IoContext) { + // Handle io read readiness + let st = poll_fn(|cx| { + let mut modify = false; + let mut readable = false; + let mut writable = false; + let read = match context.poll_read_ready(cx) { + Poll::Ready(ReadStatus::Ready) => { + modify = true; + readable = true; + Poll::Pending + } + Poll::Ready(ReadStatus::Terminate) => Poll::Ready(()), + Poll::Pending => { + modify = true; + Poll::Pending + } + }; + + let write = match context.poll_write_ready(cx) { + Poll::Ready(WriteStatus::Ready) => { + modify = true; + writable = true; + Poll::Pending + } + Poll::Ready(WriteStatus::Shutdown) => Poll::Ready(Status::Shutdown), + Poll::Ready(WriteStatus::Terminate) => Poll::Ready(Status::Terminate), + Poll::Pending => Poll::Pending, + }; + + if modify { + ctl.modify(readable, writable); + } + + if read.is_pending() && write.is_pending() { + Poll::Pending + } else if write.is_ready() { + write + } else { + Poll::Ready(Status::Terminate) + } + }) + .await; + + ctl.modify(false, true); + context.shutdown(st == Status::Shutdown).await; + context.stopped(ctl.close().await.err()); +} diff --git a/ntex-net/src/rt_polling/mod.rs b/ntex-net/src/rt_polling/mod.rs new file mode 100644 index 00000000..b4fb928b --- /dev/null +++ b/ntex-net/src/rt_polling/mod.rs @@ -0,0 +1,69 @@ +use std::{io::Result, net, net::SocketAddr}; + +use ntex_bytes::PoolRef; +use ntex_io::Io; +use socket2::Socket; + +pub(crate) mod connect; +mod driver; +mod io; + +#[cfg(not(target_pointer_width = "64"))] +compile_error!("Only 64bit platforms are supported"); + +/// Tcp stream wrapper for neon TcpStream +struct TcpStream(socket2::Socket); + +/// Tcp stream wrapper for neon UnixStream +struct UnixStream(socket2::Socket); + +/// Opens a TCP connection to a remote host. +pub async fn tcp_connect(addr: SocketAddr) -> Result { + let sock = crate::helpers::connect(addr).await?; + Ok(Io::new(TcpStream(crate::helpers::prep_socket(sock)?))) +} + +/// Opens a TCP connection to a remote host and use specified memory pool. +pub async fn tcp_connect_in(addr: SocketAddr, pool: PoolRef) -> Result { + let sock = crate::helpers::connect(addr).await?; + Ok(Io::with_memory_pool( + TcpStream(crate::helpers::prep_socket(sock)?), + pool, + )) +} + +/// Opens a unix stream connection. +pub async fn unix_connect<'a, P>(addr: P) -> Result +where + P: AsRef + 'a, +{ + let sock = crate::helpers::connect_unix(addr).await?; + Ok(Io::new(UnixStream(crate::helpers::prep_socket(sock)?))) +} + +/// Opens a unix stream connection and specified memory pool. +pub async fn unix_connect_in<'a, P>(addr: P, pool: PoolRef) -> Result +where + P: AsRef + 'a, +{ + let sock = crate::helpers::connect_unix(addr).await?; + Ok(Io::with_memory_pool( + UnixStream(crate::helpers::prep_socket(sock)?), + pool, + )) +} + +/// Convert std TcpStream to TcpStream +pub fn from_tcp_stream(stream: net::TcpStream) -> Result { + stream.set_nodelay(true)?; + Ok(Io::new(TcpStream(crate::helpers::prep_socket( + Socket::from(stream), + )?))) +} + +/// Convert std UnixStream to UnixStream +pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { + Ok(Io::new(UnixStream(crate::helpers::prep_socket( + Socket::from(stream), + )?))) +} diff --git a/ntex-net/src/rt_uring/connect.rs b/ntex-net/src/rt_uring/connect.rs new file mode 100644 index 00000000..ea9be3e1 --- /dev/null +++ b/ntex-net/src/rt_uring/connect.rs @@ -0,0 +1,91 @@ +use std::{cell::RefCell, io, os::fd::RawFd, rc::Rc}; + +use io_uring::{opcode, types::Fd}; +use ntex_neon::{driver::DriverApi, driver::Handler, Runtime}; +use ntex_util::channel::oneshot::Sender; +use slab::Slab; +use socket2::SockAddr; + +#[derive(Clone)] +pub(crate) struct ConnectOps(Rc); + +#[derive(Debug)] +enum Change { + Readable, + Writable, + Error(io::Error), +} + +struct ConnectOpsHandler { + inner: Rc, +} + +type Operations = RefCell, Sender>)>>; + +struct ConnectOpsInner { + api: DriverApi, + ops: Operations, +} + +impl ConnectOps { + pub(crate) fn current() -> Self { + Runtime::value(|rt| { + let mut inner = None; + rt.driver().register(|api| { + if !api.is_supported(opcode::Connect::CODE) { + panic!("opcode::Connect is required for io-uring support"); + } + + let ops = Rc::new(ConnectOpsInner { + api, + ops: RefCell::new(Slab::new()), + }); + inner = Some(ops.clone()); + Box::new(ConnectOpsHandler { inner: ops }) + }); + ConnectOps(inner.unwrap()) + }) + } + + pub(crate) fn connect( + &self, + fd: RawFd, + addr: SockAddr, + sender: Sender>, + ) -> io::Result<()> { + let addr2 = addr.clone(); + let mut ops = self.0.ops.borrow_mut(); + + // addr must be stable, neon submits ops at the end of rt turn + let addr = Box::new(addr); + let (addr_ptr, addr_len) = (addr.as_ref().as_ptr(), addr.len()); + + let id = ops.insert((addr, sender)); + self.0.api.submit( + id as u32, + opcode::Connect::new(Fd(fd), addr_ptr, addr_len).build(), + ); + + Ok(()) + } +} + +impl Handler for ConnectOpsHandler { + fn canceled(&mut self, user_data: usize) { + log::debug!("connect-op is canceled {:?}", user_data); + + self.inner.ops.borrow_mut().remove(user_data); + } + + fn completed(&mut self, user_data: usize, flags: u32, result: io::Result) { + let (addr, tx) = self.inner.ops.borrow_mut().remove(user_data); + log::debug!( + "connect-op is completed {:?} result: {:?}, addr: {:?}", + user_data, + result, + addr.as_socket() + ); + + let _ = tx.send(result.map(|_| ())); + } +} diff --git a/ntex-net/src/rt_uring/driver.rs b/ntex-net/src/rt_uring/driver.rs new file mode 100644 index 00000000..d39d69e8 --- /dev/null +++ b/ntex-net/src/rt_uring/driver.rs @@ -0,0 +1,444 @@ +use std::{cell::RefCell, io, mem, num::NonZeroU32, os, rc::Rc, task::Poll}; + +use io_uring::{opcode, squeue::Entry, types::Fd}; +use ntex_neon::{driver::DriverApi, driver::Handler, Runtime}; +use ntex_util::channel::oneshot; +use slab::Slab; + +use ntex_bytes::{Buf, BufMut, BytesVec}; +use ntex_io::IoContext; + +pub(crate) struct StreamCtl { + id: usize, + inner: Rc>, +} + +bitflags::bitflags! { + #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + struct Flags: u8 { + const RD_CANCELING = 0b0000_0001; + const RD_REISSUE = 0b0000_0010; + const WR_CANCELING = 0b0001_0000; + const WR_REISSUE = 0b0010_0000; + } +} + +struct StreamItem { + io: Option, + fd: Fd, + context: IoContext, + ref_count: usize, + flags: Flags, + rd_op: Option, + wr_op: Option, +} + +impl StreamItem { + fn tag(&self) -> &'static str { + self.context.tag() + } +} + +enum Operation { + Recv { + id: usize, + buf: BytesVec, + context: IoContext, + }, + Send { + id: usize, + buf: BytesVec, + context: IoContext, + }, + Close { + tx: Option>>, + }, + Nop, +} + +pub(crate) struct StreamOps(Rc>); + +struct StreamOpsHandler { + inner: Rc>, +} + +struct StreamOpsInner { + api: DriverApi, + feed: RefCell>, + storage: RefCell>, +} + +struct StreamOpsStorage { + ops: Slab, + streams: Slab>, +} + +impl StreamOps { + pub(crate) fn current() -> Self { + Runtime::value(|rt| { + let mut inner = None; + rt.driver().register(|api| { + if !api.is_supported(opcode::Recv::CODE) { + panic!("opcode::Recv is required for io-uring support"); + } + if !api.is_supported(opcode::Send::CODE) { + panic!("opcode::Send is required for io-uring support"); + } + if !api.is_supported(opcode::Close::CODE) { + panic!("opcode::Close is required for io-uring support"); + } + + let mut ops = Slab::new(); + ops.insert(Operation::Nop); + + let ops = Rc::new(StreamOpsInner { + api, + feed: RefCell::new(Vec::new()), + storage: RefCell::new(StreamOpsStorage { + ops, + streams: Slab::new(), + }), + }); + inner = Some(ops.clone()); + Box::new(StreamOpsHandler { inner: ops }) + }); + + StreamOps(inner.unwrap()) + }) + } + + pub(crate) fn register(&self, io: T, context: IoContext) -> StreamCtl { + let item = StreamItem { + context, + fd: Fd(io.as_raw_fd()), + io: Some(io), + ref_count: 1, + rd_op: None, + wr_op: None, + flags: Flags::empty(), + }; + let id = self.0.storage.borrow_mut().streams.insert(item); + StreamCtl { + id, + inner: self.0.clone(), + } + } + + fn with(&self, f: F) -> R + where + F: FnOnce(&mut StreamOpsStorage) -> R, + { + f(&mut *self.0.storage.borrow_mut()) + } +} + +impl Clone for StreamOps { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Handler for StreamOpsHandler { + fn canceled(&mut self, user_data: usize) { + let mut storage = self.inner.storage.borrow_mut(); + + match storage.ops.remove(user_data) { + Operation::Recv { id, buf, context } => { + log::debug!("{}: Recv canceled {:?}", context.tag(), id); + context.release_read_buf(buf); + if let Some(item) = storage.streams.get_mut(id) { + item.rd_op.take(); + item.flags.remove(Flags::RD_CANCELING); + if item.flags.contains(Flags::RD_REISSUE) { + item.flags.remove(Flags::RD_REISSUE); + + let result = storage.recv(id, Some(context)); + if let Some((id, op)) = result { + self.inner.api.submit(id, op); + } + } + } + } + Operation::Send { id, buf, context } => { + log::debug!("{}: Send canceled: {:?}", context.tag(), id); + context.release_write_buf(buf); + if let Some(item) = storage.streams.get_mut(id) { + item.wr_op.take(); + item.flags.remove(Flags::WR_CANCELING); + if item.flags.contains(Flags::WR_REISSUE) { + item.flags.remove(Flags::WR_REISSUE); + + let result = storage.send(id, Some(context)); + if let Some((id, op)) = result { + self.inner.api.submit(id, op); + } + } + } + } + Operation::Nop | Operation::Close { .. } => {} + } + } + + fn completed(&mut self, user_data: usize, flags: u32, result: io::Result) { + let mut storage = self.inner.storage.borrow_mut(); + + let op = storage.ops.remove(user_data); + match op { + Operation::Recv { + id, + mut buf, + context, + } => { + let result = result.map(|size| { + unsafe { buf.advance_mut(size as usize) }; + size as usize + }); + + // reset op reference + if let Some(item) = storage.streams.get_mut(id) { + log::debug!( + "{}: Recv completed {:?}, res: {:?}, buf({})", + context.tag(), + item.fd, + result, + buf.remaining_mut() + ); + item.rd_op.take(); + } + + // set read buf + let tag = context.tag(); + if context.set_read_buf(result, buf).is_pending() { + if let Some((id, op)) = storage.recv(id, Some(context)) { + self.inner.api.submit(id, op); + } + } else { + log::debug!("{}: Recv to pause", tag); + } + } + Operation::Send { id, buf, context } => { + // reset op reference + let fd = if let Some(item) = storage.streams.get_mut(id) { + log::debug!( + "{}: Send completed: {:?}, res: {:?}, buf({})", + context.tag(), + item.fd, + result, + buf.len() + ); + item.wr_op.take(); + Some(item.fd) + } else { + None + }; + + // set read buf + let result = context.set_write_buf(result.map(|size| size as usize), buf); + if result.is_pending() { + log::debug!("{}: Need to send more: {:?}", context.tag(), fd); + if let Some((id, op)) = storage.send(id, Some(context)) { + self.inner.api.submit(id, op); + } + } + } + Operation::Close { tx } => { + if let Some(tx) = tx { + let _ = tx.send(result); + } + } + Operation::Nop => {} + } + + // extra + for id in self.inner.feed.borrow_mut().drain(..) { + storage.streams[id].ref_count -= 1; + if storage.streams[id].ref_count == 0 { + let mut item = storage.streams.remove(id); + + log::debug!("{}: Drop io ({}), {:?}", item.tag(), id, item.fd); + + if let Some(io) = item.io.take() { + mem::forget(io); + + let id = storage.ops.insert(Operation::Close { tx: None }); + assert!(id < u32::MAX as usize); + self.inner + .api + .submit(id as u32, opcode::Close::new(item.fd).build()); + } + } + } + } +} + +impl StreamOpsStorage { + fn recv(&mut self, id: usize, context: Option) -> Option<(u32, Entry)> { + let item = &mut self.streams[id]; + + if item.rd_op.is_none() { + if let Poll::Ready(mut buf) = item.context.get_read_buf() { + log::debug!( + "{}: Recv resume ({}), {:?} rem: {:?}", + item.tag(), + id, + item.fd, + buf.remaining_mut() + ); + + let slice = buf.chunk_mut(); + let op = opcode::Recv::new(item.fd, slice.as_mut_ptr(), slice.len() as u32) + .build(); + + let op_id = self.ops.insert(Operation::Recv { + id, + buf, + context: context.unwrap_or_else(|| item.context.clone()), + }); + assert!(op_id < u32::MAX as usize); + + item.rd_op = NonZeroU32::new(op_id as u32); + return Some((op_id as u32, op)); + } + } else if item.flags.contains(Flags::RD_CANCELING) { + item.flags.insert(Flags::RD_REISSUE); + } + None + } + + fn send(&mut self, id: usize, context: Option) -> Option<(u32, Entry)> { + let item = &mut self.streams[id]; + + if item.wr_op.is_none() { + if let Poll::Ready(buf) = item.context.get_write_buf() { + log::debug!( + "{}: Send resume ({}), {:?} len: {:?}", + item.tag(), + id, + item.fd, + buf.len() + ); + + let slice = buf.chunk(); + let op = + opcode::Send::new(item.fd, slice.as_ptr(), slice.len() as u32).build(); + + let op_id = self.ops.insert(Operation::Send { + id, + buf, + context: context.unwrap_or_else(|| item.context.clone()), + }); + assert!(op_id < u32::MAX as usize); + + item.wr_op = NonZeroU32::new(op_id as u32); + return Some((op_id as u32, op)); + } + } else if item.flags.contains(Flags::WR_CANCELING) { + item.flags.insert(Flags::WR_REISSUE); + } + None + } +} + +impl StreamCtl { + pub(crate) async fn close(self) -> io::Result<()> { + let result = { + let mut storage = self.inner.storage.borrow_mut(); + + let (io, fd) = { + let item = &mut storage.streams[self.id]; + (item.io.take(), item.fd) + }; + if let Some(io) = io { + mem::forget(io); + + let (tx, rx) = oneshot::channel(); + let id = storage.ops.insert(Operation::Close { tx: Some(tx) }); + assert!(id < u32::MAX as usize); + + drop(storage); + self.inner + .api + .submit(id as u32, opcode::Close::new(fd).build()); + Some(rx) + } else { + None + } + }; + + if let Some(rx) = result { + rx.await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "gone")) + .and_then(|item| item) + .map(|_| ()) + } else { + Ok(()) + } + } + + pub(crate) fn with_io(&self, f: F) -> R + where + F: FnOnce(Option<&T>) -> R, + { + f(self.inner.storage.borrow().streams[self.id].io.as_ref()) + } + + pub(crate) fn resume_read(&self) { + let result = self.inner.storage.borrow_mut().recv(self.id, None); + if let Some((id, op)) = result { + self.inner.api.submit(id, op); + } + } + + pub(crate) fn resume_write(&self) { + let result = self.inner.storage.borrow_mut().send(self.id, None); + if let Some((id, op)) = result { + self.inner.api.submit(id, op); + } + } + + pub(crate) fn pause_read(&self) { + let mut storage = self.inner.storage.borrow_mut(); + let item = &mut storage.streams[self.id]; + + if let Some(rd_op) = item.rd_op { + if !item.flags.contains(Flags::RD_CANCELING) { + log::debug!("{}: Recv to pause ({}), {:?}", item.tag(), self.id, item.fd); + item.flags.insert(Flags::RD_CANCELING); + self.inner.api.cancel(rd_op.get()); + } + } + } +} + +impl Clone for StreamCtl { + fn clone(&self) -> Self { + self.inner.storage.borrow_mut().streams[self.id].ref_count += 1; + Self { + id: self.id, + inner: self.inner.clone(), + } + } +} + +impl Drop for StreamCtl { + fn drop(&mut self) { + if let Ok(mut storage) = self.inner.storage.try_borrow_mut() { + storage.streams[self.id].ref_count -= 1; + if storage.streams[self.id].ref_count == 0 { + let mut item = storage.streams.remove(self.id); + if let Some(io) = item.io.take() { + log::debug!("{}: Close io ({}), {:?}", item.tag(), self.id, item.fd); + mem::forget(io); + + let id = storage.ops.insert(Operation::Close { tx: None }); + assert!(id < u32::MAX as usize); + self.inner + .api + .submit(id as u32, opcode::Close::new(item.fd).build()); + } + } + } else { + self.inner.feed.borrow_mut().push(self.id); + } + } +} diff --git a/ntex-net/src/rt_uring/io.rs b/ntex-net/src/rt_uring/io.rs new file mode 100644 index 00000000..2f111ad7 --- /dev/null +++ b/ntex-net/src/rt_uring/io.rs @@ -0,0 +1,95 @@ +use std::{any, future::poll_fn, task::Poll}; + +use ntex_io::{ + types, Handle, IoContext, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus, +}; +use ntex_rt::spawn; +use socket2::Socket; + +use super::driver::{StreamCtl, StreamOps}; + +impl IoStream for super::TcpStream { + fn start(self, read: ReadContext, _: WriteContext) -> Option> { + let io = self.0; + let context = read.context(); + let ctl = StreamOps::current().register(io, context.clone()); + let ctl2 = ctl.clone(); + spawn(async move { run(ctl, context).await }); + + Some(Box::new(HandleWrapper(ctl2))) + } +} + +impl IoStream for super::UnixStream { + fn start(self, read: ReadContext, _: WriteContext) -> Option> { + let io = self.0; + let context = read.context(); + let ctl = StreamOps::current().register(io, context.clone()); + spawn(async move { run(ctl, context).await }); + + None + } +} + +struct HandleWrapper(StreamCtl); + +impl Handle for HandleWrapper { + fn query(&self, id: any::TypeId) -> Option> { + if id == any::TypeId::of::() { + let addr = self.0.with_io(|io| io.and_then(|io| io.peer_addr().ok())); + if let Some(addr) = addr.and_then(|addr| addr.as_socket()) { + return Some(Box::new(types::PeerAddr(addr))); + } + } + None + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum Status { + Shutdown, + Terminate, +} + +async fn run(ctl: StreamCtl, context: IoContext) { + // Handle io readiness + let st = poll_fn(|cx| { + let read = match context.poll_read_ready(cx) { + Poll::Ready(ReadStatus::Ready) => { + ctl.resume_read(); + Poll::Pending + } + Poll::Ready(ReadStatus::Terminate) => Poll::Ready(()), + Poll::Pending => { + ctl.pause_read(); + Poll::Pending + } + }; + + let write = match context.poll_write_ready(cx) { + Poll::Ready(WriteStatus::Ready) => { + ctl.resume_write(); + Poll::Pending + } + Poll::Ready(WriteStatus::Shutdown) => Poll::Ready(Status::Shutdown), + Poll::Ready(WriteStatus::Terminate) => Poll::Ready(Status::Terminate), + Poll::Pending => Poll::Pending, + }; + + if read.is_pending() && write.is_pending() { + Poll::Pending + } else if write.is_ready() { + write + } else { + Poll::Ready(Status::Terminate) + } + }) + .await; + + ctl.pause_read(); + ctl.resume_write(); + context.shutdown(st == Status::Shutdown).await; + + let result = ctl.close().await; + context.stopped(result.err()); +} diff --git a/ntex-net/src/rt_uring/mod.rs b/ntex-net/src/rt_uring/mod.rs new file mode 100644 index 00000000..41016d09 --- /dev/null +++ b/ntex-net/src/rt_uring/mod.rs @@ -0,0 +1,66 @@ +use std::{io::Result, net, net::SocketAddr}; + +use ntex_bytes::PoolRef; +use ntex_io::Io; +use socket2::Socket; + +pub(crate) mod connect; +mod driver; +mod io; + +/// Tcp stream wrapper for neon TcpStream +struct TcpStream(Socket); + +/// Tcp stream wrapper for neon UnixStream +struct UnixStream(Socket); + +/// Opens a TCP connection to a remote host. +pub async fn tcp_connect(addr: SocketAddr) -> Result { + let sock = crate::helpers::connect(addr).await?; + Ok(Io::new(TcpStream(crate::helpers::prep_socket(sock)?))) +} + +/// Opens a TCP connection to a remote host and use specified memory pool. +pub async fn tcp_connect_in(addr: SocketAddr, pool: PoolRef) -> Result { + let sock = crate::helpers::connect(addr).await?; + Ok(Io::with_memory_pool( + TcpStream(crate::helpers::prep_socket(sock)?), + pool, + )) +} + +/// Opens a unix stream connection. +pub async fn unix_connect<'a, P>(addr: P) -> Result +where + P: AsRef + 'a, +{ + let sock = crate::helpers::connect_unix(addr).await?; + Ok(Io::new(UnixStream(crate::helpers::prep_socket(sock)?))) +} + +/// Opens a unix stream connection and specified memory pool. +pub async fn unix_connect_in<'a, P>(addr: P, pool: PoolRef) -> Result +where + P: AsRef + 'a, +{ + let sock = crate::helpers::connect_unix(addr).await?; + Ok(Io::with_memory_pool( + UnixStream(crate::helpers::prep_socket(sock)?), + pool, + )) +} + +/// Convert std TcpStream to tokio's TcpStream +pub fn from_tcp_stream(stream: net::TcpStream) -> Result { + stream.set_nodelay(true)?; + Ok(Io::new(TcpStream(crate::helpers::prep_socket( + Socket::from(stream), + )?))) +} + +/// Convert std UnixStream to tokio's UnixStream +pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { + Ok(Io::new(UnixStream(crate::helpers::prep_socket( + Socket::from(stream), + )?))) +} diff --git a/ntex-router/src/lib.rs b/ntex-router/src/lib.rs index c0951890..006b4063 100644 --- a/ntex-router/src/lib.rs +++ b/ntex-router/src/lib.rs @@ -1,9 +1,4 @@ -#![deny( - rust_2018_idioms, - warnings, - unreachable_pub, - missing_debug_implementations -)] +#![deny(warnings, unreachable_pub, missing_debug_implementations)] #![warn(nonstandard_style, future_incompatible)] //! Resource path matching library. @@ -42,7 +37,7 @@ impl ResourcePath for String { } } -impl<'a> ResourcePath for &'a str { +impl ResourcePath for &str { fn path(&self) -> &str { self } @@ -54,7 +49,7 @@ impl ResourcePath for ntex_bytes::ByteString { } } -impl<'a, T: ResourcePath> ResourcePath for &'a T { +impl ResourcePath for &T { fn path(&self) -> &str { (*self).path() } @@ -71,13 +66,13 @@ impl IntoPattern for String { } } -impl<'a> IntoPattern for &'a String { +impl IntoPattern for &String { fn patterns(&self) -> Vec { vec![self.as_str().to_string()] } } -impl<'a> IntoPattern for &'a str { +impl IntoPattern for &str { fn patterns(&self) -> Vec { vec![(*self).to_string()] } diff --git a/ntex-router/src/quoter.rs b/ntex-router/src/quoter.rs index f619ea37..11593c3a 100644 --- a/ntex-router/src/quoter.rs +++ b/ntex-router/src/quoter.rs @@ -63,5 +63,5 @@ fn from_hex(v: u8) -> Option { #[inline] fn restore_ch(d1: u8, d2: u8) -> Option { - from_hex(d1).and_then(|d1| from_hex(d2).map(move |d2| d1 << 4 | d2)) + from_hex(d1).and_then(|d1| from_hex(d2).map(move |d2| (d1 << 4) | d2)) } diff --git a/ntex-rt/CHANGES.md b/ntex-rt/CHANGES.md index c2201fe0..2afd5bd6 100644 --- a/ntex-rt/CHANGES.md +++ b/ntex-rt/CHANGES.md @@ -1,5 +1,45 @@ # Changes +## [0.4.29] - 2025-03-26 + +* Add Arbiter::get_value() helper method + +## [0.4.27] - 2025-03-14 + +* Add srbiters pings ttl + +* Retrieves a list of all arbiters in the system + +* Add "neon" runtime support + +* Drop glommio support + +* Drop async-std support + +## [0.4.26] - 2025-03-12 + +* Add Arbiter::spawn_with() + +## [0.4.25] - 2025-03-11 + +* Adds Send bound to arbiter exec (#514) + +## [0.4.24] - 2025-01-03 + +* Relax runtime requirements + +## [0.4.23] - 2024-12-10 + +* Remove Unpin requirements for Arbiter::spawn() + +## [0.4.22] - 2024-12-01 + +* Depend on individual compio packages + +## [0.4.21] - 2024-11-25 + +* Update to compio 0.13 + ## [0.4.20] - 2024-10-17 * Allow to skip runtime feature for clippy run diff --git a/ntex-rt/Cargo.toml b/ntex-rt/Cargo.toml index 08b5d49d..a5966d76 100644 --- a/ntex-rt/Cargo.toml +++ b/ntex-rt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-rt" -version = "0.4.20" +version = "0.4.29" authors = ["ntex contributors "] description = "ntex runtime" keywords = ["network", "framework", "async", "futures"] @@ -20,33 +20,26 @@ path = "src/lib.rs" [features] default = [] -# glommio support -glommio = ["glomm-io", "futures-channel"] - # tokio support tokio = ["tok-io"] # compio support -compio = ["comp-io"] +compio = ["compio-driver", "compio-runtime"] -# async-std support -async-std = ["async_std/unstable"] +# neon runtime +neon = ["ntex-neon"] [dependencies] async-channel = "2" -futures-core = "0.3" -log = "0.4" +futures-timer = "3.0" oneshot = "0.1" +log = "0.4" -async_std = { version = "1", package = "async-std", optional = true } -comp-io = { version = "0.12", package = "compio", default-features = false, features = [ - "runtime" -], optional = true } +compio-driver = { version = "0.6", optional = true } +compio-runtime = { version = "0.6", optional = true } tok-io = { version = "1", package = "tokio", default-features = false, features = [ "rt", "net", ], optional = true } -[target.'cfg(target_os = "linux")'.dependencies] -glomm-io = { version = "0.9", package = "glommio", optional = true } -futures-channel = { version = "0.3", optional = true } +ntex-neon = { version = "0.1.14", optional = true } diff --git a/ntex-rt/build.rs b/ntex-rt/build.rs index 2e51b24f..f4ea08b7 100644 --- a/ntex-rt/build.rs +++ b/ntex-rt/build.rs @@ -1,33 +1,21 @@ use std::{collections::HashSet, env}; fn main() { - let mut clippy = false; let mut features = HashSet::<&'static str>::default(); - for (key, val) in env::vars() { + for (key, _) in env::vars() { let _ = match key.as_ref() { "CARGO_FEATURE_COMPIO" => features.insert("compio"), "CARGO_FEATURE_TOKIO" => features.insert("tokio"), - "CARGO_FEATURE_GLOMMIO" => features.insert("glommio"), - "CARGO_FEATURE_ASYNC_STD" => features.insert("async-std"), - "CARGO_CFG_FEATURE" => { - if val.contains("cargo-clippy") { - clippy = true; - } - false - } + "CARGO_FEATURE_NEON" => features.insert("neon"), _ => false, }; } - if !clippy { - if features.is_empty() { - panic!("Runtime must be selected '--feature=ntex/$runtime', available options are \"compio\", \"tokio\", \"async-std\", \"glommio\""); - } else if features.len() > 1 { - panic!( - "Only one runtime feature could be selected, current selection {:?}", - features - ); - } + if features.len() > 1 { + panic!( + "Only one runtime feature could be selected, current selection {:?}", + features + ); } } diff --git a/ntex-rt/src/arbiter.rs b/ntex-rt/src/arbiter.rs index f2c5e939..e20ab282 100644 --- a/ntex-rt/src/arbiter.rs +++ b/ntex-rt/src/arbiter.rs @@ -1,27 +1,22 @@ #![allow(clippy::let_underscore_future)] use std::any::{Any, TypeId}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::{ready, Context, Poll}; +use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc}; use std::{cell::RefCell, collections::HashMap, fmt, future::Future, pin::Pin, thread}; use async_channel::{unbounded, Receiver, Sender}; -use futures_core::stream::Stream; -use crate::system::System; +use crate::system::{FnExec, Id, System, SystemCommand}; thread_local!( static ADDR: RefCell> = const { RefCell::new(None) }; static STORAGE: RefCell>> = RefCell::new(HashMap::new()); ); -type ServerCommandRx = Pin>>; -type ArbiterCommandRx = Pin>>; - pub(super) static COUNT: AtomicUsize = AtomicUsize::new(0); pub(super) enum ArbiterCommand { Stop, - Execute(Box + Unpin + Send>), + Execute(Pin + Send>>), ExecuteFn(Box), } @@ -31,13 +26,16 @@ pub(super) enum ArbiterCommand { /// When an Arbiter is created, it spawns a new OS thread, and /// hosts an event loop. Some Arbiter functions execute on the current thread. pub struct Arbiter { + id: usize, + pub(crate) sys_id: usize, + name: Arc, sender: Sender, thread_handle: Option>, } impl fmt::Debug for Arbiter { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Arbiter") + write!(f, "Arbiter({:?})", self.name.as_ref()) } } @@ -49,26 +47,20 @@ impl Default for Arbiter { impl Clone for Arbiter { fn clone(&self) -> Self { - Self::with_sender(self.sender.clone()) + Self::with_sender(self.sys_id, self.id, self.name.clone(), self.sender.clone()) } } impl Arbiter { #[allow(clippy::borrowed_box)] - pub(super) fn new_system() -> (Self, ArbiterController) { + pub(super) fn new_system(name: String) -> (Self, ArbiterController) { let (tx, rx) = unbounded(); - let arb = Arbiter::with_sender(tx); + let arb = Arbiter::with_sender(0, 0, Arc::new(name), tx); ADDR.with(|cell| *cell.borrow_mut() = Some(arb.clone())); STORAGE.with(|cell| cell.borrow_mut().clear()); - ( - arb, - ArbiterController { - stop: None, - rx: Box::pin(rx), - }, - ) + (arb, ArbiterController { rx, stop: None }) } /// Returns the current thread's arbiter's address. If no Arbiter is present, then this @@ -85,27 +77,37 @@ impl Arbiter { let _ = self.sender.try_send(ArbiterCommand::Stop); } - /// Spawn new thread and run event loop in spawned thread. + /// Spawn new thread and run runtime in spawned thread. /// Returns address of newly created arbiter. pub fn new() -> Arbiter { + let name = format!("ntex-rt:worker:{}", COUNT.load(Ordering::Relaxed) + 1); + Arbiter::with_name(name) + } + + /// Spawn new thread and run runtime in spawned thread. + /// Returns address of newly created arbiter. + pub fn with_name(name: String) -> Arbiter { let id = COUNT.fetch_add(1, Ordering::Relaxed); - let name = format!("ntex-rt:worker:{}", id); let sys = System::current(); + let name2 = Arc::new(name.clone()); let config = sys.config(); let (arb_tx, arb_rx) = unbounded(); let arb_tx2 = arb_tx.clone(); let builder = if sys.config().stack_size > 0 { thread::Builder::new() - .name(name.clone()) + .name(name) .stack_size(sys.config().stack_size) } else { - thread::Builder::new().name(name.clone()) + thread::Builder::new().name(name) }; + let name = name2.clone(); + let sys_id = sys.id(); + let handle = builder .spawn(move || { - let arb = Arbiter::with_sender(arb_tx); + let arb = Arbiter::with_sender(sys_id.0, id, name2, arb_tx); let (stop, stop_rx) = oneshot::channel(); STORAGE.with(|cell| cell.borrow_mut().clear()); @@ -114,16 +116,19 @@ impl Arbiter { config.block_on(async move { // start arbiter controller - let _ = crate::spawn(ArbiterController { - stop: Some(stop), - rx: Box::pin(arb_rx), - }); + let _ = crate::spawn( + ArbiterController { + stop: Some(stop), + rx: arb_rx, + } + .run(), + ); ADDR.with(|cell| *cell.borrow_mut() = Some(arb.clone())); // register arbiter let _ = System::current() .sys() - .try_send(SystemCommand::RegisterArbiter(id, arb)); + .try_send(SystemCommand::RegisterArbiter(Id(id), arb)); // run loop let _ = stop_rx.await; @@ -132,32 +137,84 @@ impl Arbiter { // unregister arbiter let _ = System::current() .sys() - .try_send(SystemCommand::UnregisterArbiter(id)); + .try_send(SystemCommand::UnregisterArbiter(Id(id))); }) .unwrap_or_else(|err| { panic!("Cannot spawn an arbiter's thread {:?}: {:?}", &name, err) }); Arbiter { + id, + name, + sys_id: sys_id.0, sender: arb_tx2, thread_handle: Some(handle), } } + fn with_sender( + sys_id: usize, + id: usize, + name: Arc, + sender: Sender, + ) -> Self { + Self { + id, + sys_id, + name, + sender, + thread_handle: None, + } + } + + /// Id of the arbiter + pub fn id(&self) -> Id { + Id(self.id) + } + + /// Name of the arbiter + pub fn name(&self) -> &str { + self.name.as_ref() + } + /// Send a future to the Arbiter's thread, and spawn it. pub fn spawn(&self, future: F) where - F: Future + Send + Unpin + 'static, + F: Future + Send + 'static, { let _ = self .sender - .try_send(ArbiterCommand::Execute(Box::new(future))); + .try_send(ArbiterCommand::Execute(Box::pin(future))); } + #[rustfmt::skip] + /// Send a function to the Arbiter's thread and spawns it's resulting future. + /// This can be used to spawn non-send futures on the arbiter thread. + pub fn spawn_with( + &self, + f: F + ) -> impl Future> + Send + 'static + where + F: FnOnce() -> R + Send + 'static, + R: Future + 'static, + O: Send + 'static, + { + let (tx, rx) = oneshot::channel(); + let _ = self + .sender + .try_send(ArbiterCommand::ExecuteFn(Box::new(move || { + crate::spawn(async move { + let _ = tx.send(f().await); + }); + }))); + rx + } + + #[rustfmt::skip] /// Send a function to the Arbiter's thread. This function will be executed asynchronously. /// A future is created, and when resolved will contain the result of the function sent /// to the Arbiters thread. - pub fn exec(&self, f: F) -> impl Future> + pub fn exec(&self, f: F) -> impl Future> + Send + 'static where F: FnOnce() -> R + Send + 'static, R: Send + 'static, @@ -229,11 +286,23 @@ impl Arbiter { }) } - fn with_sender(sender: Sender) -> Self { - Self { - sender, - thread_handle: None, - } + /// Get a type previously inserted to this runtime or create new one. + pub fn get_value(f: F) -> T + where + T: Clone + 'static, + F: FnOnce() -> T, + { + STORAGE.with(move |cell| { + let mut st = cell.borrow_mut(); + if let Some(boxed) = st.get(&TypeId::of::()) { + if let Some(val) = (&**boxed as &(dyn Any + 'static)).downcast_ref::() { + return val.clone(); + } + } + let val = f(); + st.insert(TypeId::of::(), Box::new(val.clone())); + val + }) } /// Wait for the event loop to stop by joining the underlying thread (if have Some). @@ -246,9 +315,17 @@ impl Arbiter { } } +impl Eq for Arbiter {} + +impl PartialEq for Arbiter { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.sys_id == other.sys_id + } +} + pub(crate) struct ArbiterController { stop: Option>, - rx: ArbiterCommandRx, + rx: Receiver, } impl Drop for ArbiterController { @@ -264,118 +341,28 @@ impl Drop for ArbiterController { } } -impl Future for ArbiterController { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { +impl ArbiterController { + pub(super) async fn run(mut self) { loop { - match Pin::new(&mut self.rx).poll_next(cx) { - Poll::Ready(None) => return Poll::Ready(()), - Poll::Ready(Some(item)) => match item { - ArbiterCommand::Stop => { - if let Some(stop) = self.stop.take() { - let _ = stop.send(0); - }; - return Poll::Ready(()); - } - ArbiterCommand::Execute(fut) => { - let _ = crate::spawn(fut); - } - ArbiterCommand::ExecuteFn(f) => { - f.call_box(); - } - }, - Poll::Pending => return Poll::Pending, - } - } - } -} - -#[derive(Debug)] -pub(super) enum SystemCommand { - Exit(i32), - RegisterArbiter(usize, Arbiter), - UnregisterArbiter(usize), -} - -pub(super) struct SystemArbiter { - stop: Option>, - commands: ServerCommandRx, - arbiters: HashMap, -} - -impl SystemArbiter { - pub(super) fn new( - stop: oneshot::Sender, - commands: Receiver, - ) -> Self { - SystemArbiter { - commands: Box::pin(commands), - stop: Some(stop), - arbiters: HashMap::new(), - } - } -} - -impl fmt::Debug for SystemArbiter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SystemArbiter") - .field("arbiters", &self.arbiters) - .finish() - } -} - -impl Future for SystemArbiter { - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - let cmd = ready!(Pin::new(&mut self.commands).poll_next(cx)); - log::debug!("Received system command: {:?}", cmd); - match cmd { - None => { - log::debug!("System stopped"); - return Poll::Ready(()); + match self.rx.recv().await { + Ok(ArbiterCommand::Stop) => { + if let Some(stop) = self.stop.take() { + let _ = stop.send(0); + }; + break; } - Some(cmd) => match cmd { - SystemCommand::Exit(code) => { - log::debug!("Stopping system with {} code", code); - - // stop arbiters - for arb in self.arbiters.values() { - arb.stop(); - } - // stop event loop - if let Some(stop) = self.stop.take() { - let _ = stop.send(code); - } - } - SystemCommand::RegisterArbiter(name, hnd) => { - self.arbiters.insert(name, hnd); - } - SystemCommand::UnregisterArbiter(name) => { - self.arbiters.remove(&name); - } - }, + Ok(ArbiterCommand::Execute(fut)) => { + let _ = crate::spawn(fut); + } + Ok(ArbiterCommand::ExecuteFn(f)) => { + f.call_box(); + } + Err(_) => break, } } } } -pub(super) trait FnExec: Send + 'static { - fn call_box(self: Box); -} - -impl FnExec for F -where - F: FnOnce() + Send + 'static, -{ - #[allow(clippy::boxed_local)] - fn call_box(self: Box) { - (*self)() - } -} - #[cfg(test)] mod tests { use super::*; @@ -387,6 +374,7 @@ mod tests { assert!(Arbiter::get_item::<&'static str, _, _>(|s| *s == "test")); assert!(Arbiter::get_mut_item::<&'static str, _, _>(|s| *s == "test")); assert!(Arbiter::contains_item::<&'static str>()); + assert!(Arbiter::get_value(|| 64u64) == 64); assert!(format!("{:?}", Arbiter::current()).contains("Arbiter")); } } diff --git a/ntex-rt/src/builder.rs b/ntex-rt/src/builder.rs index 597e107a..e16c01cb 100644 --- a/ntex-rt/src/builder.rs +++ b/ntex-rt/src/builder.rs @@ -1,9 +1,9 @@ -use std::{future::Future, io, pin::Pin, sync::Arc}; +use std::{future::Future, io, marker::PhantomData, pin::Pin, rc::Rc, sync::Arc}; use async_channel::unbounded; -use crate::arbiter::{Arbiter, ArbiterController, SystemArbiter}; -use crate::{system::SystemConfig, System}; +use crate::arbiter::{Arbiter, ArbiterController}; +use crate::system::{System, SystemCommand, SystemConfig, SystemSupport}; /// Builder struct for a ntex runtime. /// @@ -17,6 +17,8 @@ pub struct Builder { stop_on_panic: bool, /// New thread stack size stack_size: usize, + /// Arbiters ping interval + ping_interval: usize, /// Block on fn block_on: Option>>) + Sync + Send>>, } @@ -28,6 +30,7 @@ impl Builder { stop_on_panic: false, stack_size: 0, block_on: None, + ping_interval: 1000, } } @@ -52,6 +55,15 @@ impl Builder { self } + /// Sets ping interval for spawned arbiters. + /// + /// Interval is in milliseconds. By default 5000 milliseconds is set. + /// To disable pings set value to zero. + pub fn ping_interval(mut self, interval: usize) -> Self { + self.ping_interval = interval; + self + } + /// Use custom block_on function pub fn block_on(mut self, block_on: F) -> Self where @@ -74,18 +86,20 @@ impl Builder { stop_on_panic: self.stop_on_panic, }; - let (arb, arb_controller) = Arbiter::new_system(); - let system = System::construct(sys_sender, arb, config); + let (arb, controller) = Arbiter::new_system(self.name.clone()); + let _ = sys_sender.try_send(SystemCommand::RegisterArbiter(arb.id(), arb.clone())); + let system = System::construct(sys_sender, arb.clone(), config); // system arbiter - let arb = SystemArbiter::new(stop_tx, sys_receiver); + let support = SystemSupport::new(stop_tx, sys_receiver, self.ping_interval); // init system arbiter and run configuration method SystemRunner { stop, - arb, - arb_controller, + support, + controller, system, + _t: PhantomData, } } } @@ -94,9 +108,10 @@ impl Builder { #[must_use = "SystemRunner must be run"] pub struct SystemRunner { stop: oneshot::Receiver, - arb: SystemArbiter, - arb_controller: ArbiterController, + support: SystemSupport, + controller: ArbiterController, system: System, + _t: PhantomData>, } impl SystemRunner { @@ -113,15 +128,14 @@ impl SystemRunner { /// This function will start event loop and will finish once the /// `System::stop()` function is called. - #[inline] pub fn run(self, f: F) -> io::Result<()> where F: FnOnce() -> io::Result<()> + 'static, { let SystemRunner { + controller, stop, - arb, - arb_controller, + support, system, .. } = self; @@ -130,8 +144,8 @@ impl SystemRunner { system.config().block_on(async move { f()?; - let _ = crate::spawn(arb); - let _ = crate::spawn(arb_controller); + let _ = crate::spawn(support.run()); + let _ = crate::spawn(controller.run()); match stop.await { Ok(code) => { if code != 0 { @@ -149,22 +163,21 @@ impl SystemRunner { } /// Execute a future and wait for result. - #[inline] pub fn block_on(self, fut: F) -> R where F: Future + 'static, R: 'static, { let SystemRunner { - arb, - arb_controller, + controller, + support, system, .. } = self; system.config().block_on(async move { - let _ = crate::spawn(arb); - let _ = crate::spawn(arb_controller); + let _ = crate::spawn(support.run()); + let _ = crate::spawn(controller.run()); fut.await }) } @@ -177,16 +190,16 @@ impl SystemRunner { R: 'static, { let SystemRunner { - arb, - arb_controller, + controller, + support, .. } = self; // run loop tok_io::task::LocalSet::new() .run_until(async move { - let _ = crate::spawn(arb); - let _ = crate::spawn(arb_controller); + let _ = crate::spawn(support.run()); + let _ = crate::spawn(controller.run()); fut.await }) .await @@ -242,6 +255,7 @@ mod tests { thread::spawn(move || { let runner = crate::System::build() .stop_on_panic(true) + .ping_interval(25) .block_on(|fut| { let rt = tok_io::runtime::Builder::new_current_thread() .enable_all() @@ -270,6 +284,18 @@ mod tests { .unwrap(); assert_eq!(id, id2); + let (tx, rx) = mpsc::channel(); + sys.arbiter().spawn(async move { + futures_timer::Delay::new(std::time::Duration::from_millis(100)).await; + + let recs = System::list_arbiter_pings(Arbiter::current().id(), |recs| { + recs.unwrap().clone() + }); + let _ = tx.send(recs); + }); + let recs = rx.recv().unwrap(); + + assert!(!recs.is_empty()); sys.stop(); } } diff --git a/ntex-rt/src/lib.rs b/ntex-rt/src/lib.rs index d7c02858..d5d85546 100644 --- a/ntex-rt/src/lib.rs +++ b/ntex-rt/src/lib.rs @@ -8,7 +8,7 @@ mod system; pub use self::arbiter::Arbiter; pub use self::builder::{Builder, SystemRunner}; -pub use self::system::System; +pub use self::system::{Id, PingRecord, System}; thread_local! { static CB: RefCell<(TBefore, TEnter, TExit, TAfter)> = RefCell::new(( @@ -112,6 +112,8 @@ mod tokio { /// /// This function panics if ntex system is not running. #[inline] + #[doc(hidden)] + #[deprecated] pub fn spawn_fn(f: F) -> tok_io::task::JoinHandle where F: FnOnce() -> R + 'static, @@ -127,14 +129,14 @@ mod compio { use std::task::{ready, Context, Poll}; use std::{fmt, future::poll_fn, future::Future, pin::Pin}; - use comp_io::runtime::Runtime; + use compio_runtime::Runtime; /// Runs the provided future, blocking the current thread until the future /// completes. pub fn block_on>(fut: F) { log::info!( "Starting compio runtime, driver {:?}", - comp_io::driver::DriverType::current() + compio_driver::DriverType::current() ); let rt = Runtime::new().unwrap(); rt.block_on(fut); @@ -151,7 +153,7 @@ mod compio { T: Send + 'static, { JoinHandle { - fut: Some(comp_io::runtime::spawn_blocking(f)), + fut: Some(compio_runtime::spawn_blocking(f)), } } @@ -168,7 +170,7 @@ mod compio { F: Future + 'static, { let ptr = crate::CB.with(|cb| (cb.borrow().0)()); - let fut = comp_io::runtime::spawn(async move { + let fut = compio_runtime::spawn(async move { if let Some(ptr) = ptr { let mut f = std::pin::pin!(f); let result = poll_fn(|ctx| { @@ -196,6 +198,8 @@ mod compio { /// /// This function panics if ntex system is not running. #[inline] + #[doc(hidden)] + #[deprecated] pub fn spawn_fn(f: F) -> JoinHandle where F: FnOnce() -> R + 'static, @@ -216,7 +220,7 @@ mod compio { impl std::error::Error for JoinError {} pub struct JoinHandle { - fut: Option>, + fut: Option>, } impl JoinHandle { @@ -248,15 +252,38 @@ mod compio { } #[allow(dead_code)] -#[cfg(feature = "async-std")] -mod asyncstd { - use std::future::{poll_fn, Future}; - use std::{fmt, pin::Pin, task::ready, task::Context, task::Poll}; +#[cfg(feature = "neon")] +mod neon { + use std::task::{ready, Context, Poll}; + use std::{fmt, future::poll_fn, future::Future, pin::Pin}; + + use ntex_neon::Runtime; /// Runs the provided future, blocking the current thread until the future /// completes. pub fn block_on>(fut: F) { - async_std::task::block_on(fut); + let rt = Runtime::new().unwrap(); + log::info!( + "Starting neon runtime, driver {:?}", + rt.driver().tp().name() + ); + + rt.block_on(fut); + } + + /// Spawns a blocking task. + /// + /// The task will be spawned onto a thread pool specifically dedicated + /// to blocking tasks. This is useful to prevent long-running synchronous + /// operations from blocking the main futures executor. + pub fn spawn_blocking(f: F) -> JoinHandle + where + F: FnOnce() -> T + Send + Sync + 'static, + T: Send + 'static, + { + JoinHandle { + fut: Some(ntex_neon::spawn_blocking(f)), + } } /// Spawn a future on the current thread. This does not create a new Arbiter @@ -267,29 +294,29 @@ mod asyncstd { /// /// This function panics if ntex system is not running. #[inline] - pub fn spawn(mut f: F) -> JoinHandle + pub fn spawn(f: F) -> Task where F: Future + 'static, { let ptr = crate::CB.with(|cb| (cb.borrow().0)()); - JoinHandle { - fut: async_std::task::spawn_local(async move { - if let Some(ptr) = ptr { - let mut f = unsafe { Pin::new_unchecked(&mut f) }; - let result = poll_fn(|ctx| { - let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); - let result = f.as_mut().poll(ctx); - crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); - result - }) - .await; - crate::CB.with(|cb| (cb.borrow().3)(ptr)); + let task = ntex_neon::spawn(async move { + if let Some(ptr) = ptr { + let mut f = std::pin::pin!(f); + let result = poll_fn(|ctx| { + let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); + let result = f.as_mut().poll(ctx); + crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); result - } else { - f.await - } - }), - } + }) + .await; + crate::CB.with(|cb| (cb.borrow().3)(ptr)); + result + } else { + f.await + } + }); + + Task { task: Some(task) } } /// Executes a future on the current thread. This does not create a new Arbiter @@ -300,7 +327,9 @@ mod asyncstd { /// /// This function panics if ntex system is not running. #[inline] - pub fn spawn_fn(f: F) -> JoinHandle + #[doc(hidden)] + #[deprecated] + pub fn spawn_fn(f: F) -> Task where F: FnOnce() -> R + 'static, R: Future + 'static, @@ -308,18 +337,32 @@ mod asyncstd { spawn(async move { f().await }) } - /// Spawns a blocking task. - /// - /// The task will be spawned onto a thread pool specifically dedicated - /// to blocking tasks. This is useful to prevent long-running synchronous - /// operations from blocking the main futures executor. - pub fn spawn_blocking(f: F) -> JoinHandle - where - F: FnOnce() -> T + Send + 'static, - T: Send + 'static, - { - JoinHandle { - fut: async_std::task::spawn_blocking(f), + /// A spawned task. + pub struct Task { + task: Option>, + } + + impl Task { + pub fn is_finished(&self) -> bool { + if let Some(hnd) = &self.task { + hnd.is_finished() + } else { + true + } + } + } + + impl Drop for Task { + fn drop(&mut self) { + self.task.take().unwrap().detach(); + } + } + + impl Future for Task { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Poll::Ready(Ok(ready!(Pin::new(self.task.as_mut().unwrap()).poll(cx)))) } } @@ -335,128 +378,24 @@ mod asyncstd { impl std::error::Error for JoinError {} pub struct JoinHandle { - fut: async_std::task::JoinHandle, + fut: Option>, + } + + impl JoinHandle { + pub fn is_finished(&self) -> bool { + self.fut.is_none() + } } impl Future for JoinHandle { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Poll::Ready(Ok(ready!(Pin::new(&mut self.fut).poll(cx)))) - } - } -} - -#[allow(dead_code)] -#[cfg(all(feature = "glommio", target_os = "linux"))] -mod glommio { - use std::future::{poll_fn, Future}; - use std::{pin::Pin, task::Context, task::Poll}; - - use futures_channel::oneshot::Canceled; - use glomm_io::task; - - pub type JoinError = Canceled; - - /// Runs the provided future, blocking the current thread until the future - /// completes. - pub fn block_on>(fut: F) { - let ex = glomm_io::LocalExecutor::default(); - ex.run(async move { - let _ = fut.await; - }) - } - - /// Spawn a future on the current thread. This does not create a new Arbiter - /// or Arbiter address, it is simply a helper for spawning futures on the current - /// thread. - /// - /// # Panics - /// - /// This function panics if ntex system is not running. - #[inline] - pub fn spawn(mut f: F) -> JoinHandle - where - F: Future + 'static, - F::Output: 'static, - { - let ptr = crate::CB.with(|cb| (cb.borrow().0)()); - JoinHandle { - fut: Either::Left( - glomm_io::spawn_local(async move { - if let Some(ptr) = ptr { - glomm_io::executor().yield_now().await; - let mut f = unsafe { Pin::new_unchecked(&mut f) }; - let result = poll_fn(|ctx| { - let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); - let result = f.as_mut().poll(ctx); - crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); - result - }) - .await; - crate::CB.with(|cb| (cb.borrow().3)(ptr)); - result - } else { - glomm_io::executor().yield_now().await; - f.await - } - }) - .detach(), - ), - } - } - - /// Executes a future on the current thread. This does not create a new Arbiter - /// or Arbiter address, it is simply a helper for executing futures on the current - /// thread. - /// - /// # Panics - /// - /// This function panics if ntex system is not running. - #[inline] - pub fn spawn_fn(f: F) -> JoinHandle - where - F: FnOnce() -> R + 'static, - R: Future + 'static, - { - spawn(async move { f().await }) - } - - pub fn spawn_blocking(f: F) -> JoinHandle - where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, - { - let fut = glomm_io::executor().spawn_blocking(f); - JoinHandle { - fut: Either::Right(Box::pin(async move { Ok(fut.await) })), - } - } - - enum Either { - Left(T1), - Right(T2), - } - - /// Blocking operation completion future. It resolves with results - /// of blocking function execution. - #[allow(clippy::type_complexity)] - pub struct JoinHandle { - fut: - Either, Pin>>>>, - } - - impl Future for JoinHandle { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.fut { - Either::Left(ref mut f) => match Pin::new(f).poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(res) => Poll::Ready(res.ok_or(Canceled)), - }, - Either::Right(ref mut f) => Pin::new(f).poll(cx), - } + Poll::Ready( + ready!(Pin::new(self.fut.as_mut().unwrap()).poll(cx)) + .map_err(|_| JoinError) + .and_then(|result| result.map_err(|_| JoinError)), + ) } } } @@ -464,22 +403,14 @@ mod glommio { #[cfg(feature = "tokio")] pub use self::tokio::*; -#[cfg(feature = "async-std")] -pub use self::asyncstd::*; - -#[cfg(feature = "glommio")] -pub use self::glommio::*; - #[cfg(feature = "compio")] pub use self::compio::*; +#[cfg(feature = "neon")] +pub use self::neon::*; + #[allow(dead_code)] -#[cfg(all( - not(feature = "tokio"), - not(feature = "async-std"), - not(feature = "compio"), - not(feature = "glommio") -))] +#[cfg(all(not(feature = "tokio"), not(feature = "compio"), not(feature = "neon")))] mod no_rt { use std::task::{Context, Poll}; use std::{fmt, future::Future, marker::PhantomData, pin::Pin}; @@ -538,10 +469,5 @@ mod no_rt { impl std::error::Error for JoinError {} } -#[cfg(all( - not(feature = "tokio"), - not(feature = "async-std"), - not(feature = "compio"), - not(feature = "glommio") -))] +#[cfg(all(not(feature = "tokio"), not(feature = "compio"), not(feature = "neon")))] pub use self::no_rt::*; diff --git a/ntex-rt/src/system.rs b/ntex-rt/src/system.rs index 86b783ad..257f81ed 100644 --- a/ntex-rt/src/system.rs +++ b/ntex-rt/src/system.rs @@ -1,13 +1,31 @@ +use std::collections::{HashMap, VecDeque}; use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc}; +use std::time::{Duration, Instant}; use std::{cell::RefCell, fmt, future::Future, pin::Pin, rc::Rc}; -use async_channel::Sender; +use async_channel::{Receiver, Sender}; +use futures_timer::Delay; -use super::arbiter::{Arbiter, SystemCommand}; +use super::arbiter::Arbiter; use super::builder::{Builder, SystemRunner}; static SYSTEM_COUNT: AtomicUsize = AtomicUsize::new(0); +thread_local!( + static ARBITERS: RefCell = RefCell::new(Arbiters::default()); + static PINGS: RefCell>> = + RefCell::new(HashMap::default()); +); + +#[derive(Default)] +struct Arbiters { + all: HashMap, + list: Vec, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct Id(pub(crate) usize); + /// System is a runtime manager. #[derive(Clone, Debug)] pub struct System { @@ -33,14 +51,17 @@ impl System { /// Constructs new system and sets it as current pub(super) fn construct( sys: Sender, - arbiter: Arbiter, + mut arbiter: Arbiter, config: SystemConfig, ) -> Self { + let id = SYSTEM_COUNT.fetch_add(1, Ordering::SeqCst); + arbiter.sys_id = id; + let sys = System { + id, sys, config, arbiter, - id: SYSTEM_COUNT.fetch_add(1, Ordering::SeqCst), }; System::set_current(sys.clone()); sys @@ -79,8 +100,8 @@ impl System { } /// System id - pub fn id(&self) -> usize { - self.id + pub fn id(&self) -> Id { + Id(self.id) } /// Stop the system @@ -104,6 +125,34 @@ impl System { &self.arbiter } + /// Retrieves a list of all arbiters in the system. + /// + /// This method should be called from the thread where the system has been initialized, + /// typically the "main" thread. + pub fn list_arbiters(f: F) -> R + where + F: FnOnce(&[Arbiter]) -> R, + { + ARBITERS.with(|arbs| f(arbs.borrow().list.as_ref())) + } + + /// Retrieves a list of last pings records for specified arbiter. + /// + /// This method should be called from the thread where the system has been initialized, + /// typically the "main" thread. + pub fn list_arbiter_pings(id: Id, f: F) -> R + where + F: FnOnce(Option<&VecDeque>) -> R, + { + PINGS.with(|pings| { + if let Some(recs) = pings.borrow().get(&id) { + f(Some(recs)) + } else { + f(None) + } + }) + } + pub(super) fn sys(&self) -> &Sender { &self.sys } @@ -150,3 +199,173 @@ impl fmt::Debug for SystemConfig { .finish() } } + +#[derive(Debug)] +pub(super) enum SystemCommand { + Exit(i32), + RegisterArbiter(Id, Arbiter), + UnregisterArbiter(Id), +} + +pub(super) struct SystemSupport { + stop: Option>, + commands: Receiver, + ping_interval: Duration, +} + +impl SystemSupport { + pub(super) fn new( + stop: oneshot::Sender, + commands: Receiver, + ping_interval: usize, + ) -> Self { + Self { + commands, + stop: Some(stop), + ping_interval: Duration::from_millis(ping_interval as u64), + } + } + + pub(super) async fn run(mut self) { + ARBITERS.with(move |arbs| { + let mut arbiters = arbs.borrow_mut(); + arbiters.all.clear(); + arbiters.list.clear(); + }); + + loop { + match self.commands.recv().await { + Ok(SystemCommand::Exit(code)) => { + log::debug!("Stopping system with {} code", code); + + // stop arbiters + ARBITERS.with(move |arbs| { + let mut arbiters = arbs.borrow_mut(); + for arb in arbiters.list.drain(..) { + arb.stop(); + } + arbiters.all.clear(); + }); + + // stop event loop + if let Some(stop) = self.stop.take() { + let _ = stop.send(code); + } + } + Ok(SystemCommand::RegisterArbiter(id, hnd)) => { + crate::spawn(ping_arbiter(hnd.clone(), self.ping_interval)); + ARBITERS.with(move |arbs| { + let mut arbiters = arbs.borrow_mut(); + arbiters.all.insert(id, hnd.clone()); + arbiters.list.push(hnd); + }); + } + Ok(SystemCommand::UnregisterArbiter(id)) => { + ARBITERS.with(move |arbs| { + let mut arbiters = arbs.borrow_mut(); + if let Some(hnd) = arbiters.all.remove(&id) { + for (idx, arb) in arbiters.list.iter().enumerate() { + if &hnd == arb { + arbiters.list.remove(idx); + break; + } + } + } + }); + } + Err(_) => { + log::debug!("System stopped"); + return; + } + } + } + } +} + +#[derive(Copy, Clone, Debug)] +pub struct PingRecord { + /// Ping start time + pub start: Instant, + /// Round-trip time, if value is not set then ping is in process + pub rtt: Option, +} + +async fn ping_arbiter(arb: Arbiter, interval: Duration) { + loop { + Delay::new(interval).await; + + // check if arbiter is still active + let is_alive = ARBITERS.with(|arbs| arbs.borrow().all.contains_key(&arb.id())); + + if !is_alive { + PINGS.with(|pings| pings.borrow_mut().remove(&arb.id())); + break; + } + + // calc ttl + let start = Instant::now(); + PINGS.with(|pings| { + let mut p = pings.borrow_mut(); + let recs = p.entry(arb.id()).or_default(); + recs.push_front(PingRecord { start, rtt: None }); + recs.truncate(10); + }); + + let result = arb + .spawn_with(|| async { + yield_to().await; + }) + .await; + + if result.is_err() { + break; + } + + PINGS.with(|pings| { + pings + .borrow_mut() + .get_mut(&arb.id()) + .unwrap() + .front_mut() + .unwrap() + .rtt = Some(Instant::now() - start); + }); + } +} + +async fn yield_to() { + use std::task::{Context, Poll}; + + struct Yield { + completed: bool, + } + + impl Future for Yield { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if self.completed { + return Poll::Ready(()); + } + self.completed = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + Yield { completed: false }.await; +} + +pub(super) trait FnExec: Send + 'static { + fn call_box(self: Box); +} + +impl FnExec for F +where + F: FnOnce() + Send + 'static, +{ + #[allow(clippy::boxed_local)] + fn call_box(self: Box) { + (*self)() + } +} diff --git a/ntex-server/CHANGES.md b/ntex-server/CHANGES.md index c01352e4..546a92ff 100644 --- a/ntex-server/CHANGES.md +++ b/ntex-server/CHANGES.md @@ -1,5 +1,33 @@ # Changes +## [2.7.3] - 2025-03-28 + +* Better worker availability handling + +## [2.7.2] - 2025-03-27 + +* Handle paused state + +## [2.7.1] - 2025-02-28 + +* Fix set core affinity out of worker start #508 + +## [2.7.0] - 2025-01-31 + +* Cpu affinity support for workers + +## [2.6.2] - 2024-12-30 + +* Fix error log + +## [2.6.1] - 2024-12-26 + +* Tune shutdown logging + +## [2.6.0] - 2024-12-04 + +* Use updated Service trait + ## [2.5.0] - 2024-11-04 * Use updated Service trait diff --git a/ntex-server/Cargo.toml b/ntex-server/Cargo.toml index 4ada3557..a88be635 100644 --- a/ntex-server/Cargo.toml +++ b/ntex-server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-server" -version = "2.5.0" +version = "2.7.4" authors = ["ntex contributors "] description = "Server for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -18,16 +18,17 @@ path = "src/lib.rs" [dependencies] ntex-bytes = "0.1" ntex-net = "2" -ntex-service = "3.3" +ntex-service = "3.4" ntex-rt = "0.4" -ntex-util = "2.5" +ntex-util = "2.8" -async-channel = "2" -async-broadcast = "0.7" -polling = "3.3" -log = "0.4" -socket2 = "0.5" -oneshot = { version = "0.1", default-features = false, features = ["async"] } +async-channel = { workspace = true } +atomic-waker = { workspace = true } +core_affinity = { workspace = true } +oneshot = { workspace = true } +polling = { workspace = true } +log = { workspace = true } +socket2 = { workspace = true } [dev-dependencies] ntex = "2" diff --git a/ntex-server/src/manager.rs b/ntex-server/src/manager.rs index e9c736c3..9d0bfe8d 100644 --- a/ntex-server/src/manager.rs +++ b/ntex-server/src/manager.rs @@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::{cell::Cell, cell::RefCell, collections::VecDeque, rc::Rc, sync::Arc}; use async_channel::{unbounded, Receiver, Sender}; +use core_affinity::CoreId; use ntex_rt::System; use ntex_util::future::join_all; use ntex_util::time::{sleep, timeout, Millis}; @@ -69,9 +70,16 @@ impl ServerManager { // handle cmd let _ = ntex_rt::spawn(handle_cmd(mgr.clone(), rx)); + // Retrieve the IDs of all active CPU cores. + let mut cores = if cfg.affinity { + core_affinity::get_core_ids().unwrap_or_default() + } else { + Vec::new() + }; + // start workers for _ in 0..mgr.0.cfg.num { - start_worker(mgr.clone()); + start_worker(mgr.clone(), cores.pop()); } let srv = Server::new(tx, shared); @@ -128,10 +136,10 @@ impl ServerManager { } } -fn start_worker(mgr: ServerManager) { +fn start_worker(mgr: ServerManager, cid: Option) { let _ = ntex_rt::spawn(async move { let id = mgr.next_id(); - let mut wrk = Worker::start(id, mgr.factory()); + let mut wrk = Worker::start(id, mgr.factory(), cid); loop { match wrk.status() { @@ -141,7 +149,7 @@ fn start_worker(mgr: ServerManager) { mgr.unavailable(wrk); sleep(RESTART_DELAY).await; if !mgr.stopping() { - wrk = Worker::start(id, mgr.factory()); + wrk = Worker::start(id, mgr.factory(), cid); } else { return; } @@ -172,7 +180,7 @@ impl HandleCmdState { fn process(&mut self, mut item: F::Item) { loop { if !self.workers.is_empty() { - if self.next > self.workers.len() { + if self.next >= self.workers.len() { self.next = self.workers.len() - 1; } match self.workers[self.next].send(item) { @@ -203,10 +211,9 @@ impl HandleCmdState { match upd { Update::Available(worker) => { self.workers.push(worker); + self.workers.sort(); if self.workers.len() == 1 { self.mgr.resume(); - } else { - self.workers.sort(); } } Update::Unavailable(worker) => { @@ -225,6 +232,9 @@ impl HandleCmdState { if let Err(item) = self.workers[0].send(item) { self.backlog.push_back(item); self.workers.remove(0); + if self.workers.is_empty() { + self.mgr.pause(); + } break; } } @@ -262,10 +272,10 @@ impl HandleCmdState { for tx in notify { let _ = tx.send(()); } + sleep(STOP_DELAY).await; // stop system if server was spawned if self.mgr.0.cfg.stop_runtime { - sleep(STOP_DELAY).await; System::current().stop(); } } diff --git a/ntex-server/src/net/accept.rs b/ntex-server/src/net/accept.rs index 4d1d5a85..7694d286 100644 --- a/ntex-server/src/net/accept.rs +++ b/ntex-server/src/net/accept.rs @@ -92,12 +92,14 @@ impl AcceptLoop { /// Start accept loop pub fn start(mut self, socks: Vec<(Token, Listener)>, srv: Server) { + let (tx, rx_start) = oneshot::channel(); let (rx, poll) = self .inner .take() .expect("AcceptLoop cannot be used multiple times"); Accept::start( + tx, rx, poll, socks, @@ -105,6 +107,8 @@ impl AcceptLoop { self.notify.clone(), self.status_handler.take(), ); + + let _ = rx_start.recv(); } } @@ -121,6 +125,7 @@ impl fmt::Debug for AcceptLoop { struct Accept { poller: Arc, rx: mpsc::Receiver, + tx: Option>, sockets: Vec, srv: Server, notify: AcceptNotify, @@ -131,6 +136,7 @@ struct Accept { impl Accept { fn start( + tx: oneshot::Sender<()>, rx: mpsc::Receiver, poller: Arc, socks: Vec<(Token, Listener)>, @@ -145,11 +151,12 @@ impl Accept { .name("ntex-server accept loop".to_owned()) .spawn(move || { System::set_current(sys); - Accept::new(rx, poller, socks, srv, notify, status_handler).poll() + Accept::new(tx, rx, poller, socks, srv, notify, status_handler).poll() }); } fn new( + tx: oneshot::Sender<()>, rx: mpsc::Receiver, poller: Arc, socks: Vec<(Token, Listener)>, @@ -175,6 +182,7 @@ impl Accept { notify, srv, status_handler, + tx: Some(tx), backpressure: true, backlog: VecDeque::new(), } @@ -192,19 +200,23 @@ impl Accept { // Create storage for events let mut events = Events::with_capacity(NonZeroUsize::new(512).unwrap()); + let mut timeout = Some(Duration::ZERO); loop { - if let Err(e) = self.poller.wait(&mut events, None) { - if e.kind() == io::ErrorKind::Interrupted { - continue; - } else { + if let Err(e) = self.poller.wait(&mut events, timeout) { + if e.kind() != io::ErrorKind::Interrupted { panic!("Cannot wait for events in poller: {}", e) } + } else if timeout.is_some() { + timeout = None; + let _ = self.tx.take().unwrap().send(()); } - for event in events.iter() { - let readd = self.accept(event.key); - if readd { - self.add_source(event.key); + for idx in 0..self.sockets.len() { + if self.sockets[idx].registered.get() { + let readd = self.accept(idx); + if readd { + self.add_source(idx); + } } } @@ -215,13 +227,13 @@ impl Accept { for info in self.sockets.drain(..) { info.sock.remove_source() } + log::info!("Accept loop has been stopped"); if let Some(rx) = rx { thread::sleep(EXIT_TIMEOUT); let _ = rx.send(()); } - log::trace!("Accept loop has been stopped"); break; } } @@ -295,25 +307,25 @@ impl Accept { Ok(cmd) => match cmd { AcceptorCommand::Stop(rx) => { if !self.backpressure { - log::trace!("Stopping accept loop"); + log::info!("Stopping accept loop"); self.backpressure(true); } break Either::Right(Some(rx)); } AcceptorCommand::Terminate => { - log::trace!("Stopping accept loop"); + log::info!("Stopping accept loop"); self.backpressure(true); break Either::Right(None); } AcceptorCommand::Pause => { if !self.backpressure { - log::trace!("Pausing accept loop"); + log::info!("Pausing accept loop"); self.backpressure(true); } } AcceptorCommand::Resume => { if self.backpressure { - log::trace!("Resuming accept loop"); + log::info!("Resuming accept loop"); self.backpressure(false); } } @@ -325,10 +337,11 @@ impl Accept { break match err { mpsc::TryRecvError::Empty => Either::Left(()), mpsc::TryRecvError::Disconnected => { + log::error!("Dropping accept loop"); self.backpressure(true); Either::Right(None) } - } + }; } } } diff --git a/ntex-server/src/net/builder.rs b/ntex-server/src/net/builder.rs index 95be84b9..9b11684a 100644 --- a/ntex-server/src/net/builder.rs +++ b/ntex-server/src/net/builder.rs @@ -110,6 +110,14 @@ impl ServerBuilder { self } + /// Enable cpu affinity + /// + /// By default affinity is disabled. + pub fn enable_affinity(mut self) -> Self { + self.pool = self.pool.enable_affinity(); + self + } + /// Timeout for graceful workers shutdown. /// /// After receiving a stop signal, workers have this much time to finish @@ -360,7 +368,7 @@ pub fn bind_addr( Err(e) } else { Err(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::InvalidInput, "Cannot bind to address.", )) } diff --git a/ntex-server/src/net/service.rs b/ntex-server/src/net/service.rs index 4be6c828..70e9c5e3 100644 --- a/ntex-server/src/net/service.rs +++ b/ntex-server/src/net/service.rs @@ -1,4 +1,4 @@ -use std::{fmt, future::poll_fn, future::Future, pin::Pin, task::Poll}; +use std::{fmt, task::Context}; use ntex_bytes::{Pool, PoolRef}; use ntex_net::Io; @@ -170,27 +170,11 @@ impl Service for StreamServiceImpl { } #[inline] - async fn not_ready(&self) { - if self.conns.is_available() { - let mut futs: Vec<_> = self - .services - .iter() - .map(|s| Box::pin(s.not_ready())) - .collect(); - - ntex_util::future::select( - self.conns.unavailable(), - poll_fn(move |cx| { - for f in &mut futs { - if Pin::new(f).poll(cx).is_ready() { - return Poll::Ready(()); - } - } - Poll::Pending - }), - ) - .await; + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + for svc in &self.services { + svc.poll(cx)?; } + Ok(()) } async fn shutdown(&self) { diff --git a/ntex-server/src/net/test.rs b/ntex-server/src/net/test.rs index 80441628..1c78f5c5 100644 --- a/ntex-server/src/net/test.rs +++ b/ntex-server/src/net/test.rs @@ -59,8 +59,13 @@ where .workers(1) .disable_signals() .run(); - tx.send((system, local_addr, server)) - .expect("Failed to send Server to TestServer"); + + ntex_rt::spawn(async move { + ntex_util::time::sleep(ntex_util::time::Millis(75)).await; + tx.send((system, local_addr, server)) + .expect("Failed to send Server to TestServer"); + }); + Ok(()) }) }); diff --git a/ntex-server/src/pool.rs b/ntex-server/src/pool.rs index 229ea8ba..d1e76c59 100644 --- a/ntex-server/src/pool.rs +++ b/ntex-server/src/pool.rs @@ -11,6 +11,7 @@ pub struct WorkerPool { pub(crate) no_signals: bool, pub(crate) stop_runtime: bool, pub(crate) shutdown_timeout: Millis, + pub(crate) affinity: bool, } impl Default for WorkerPool { @@ -22,12 +23,18 @@ impl Default for WorkerPool { impl WorkerPool { /// Create new Server builder instance pub fn new() -> Self { + let num = core_affinity::get_core_ids() + .map(|v| v.len()) + .unwrap_or_else(|| { + std::thread::available_parallelism().map_or(2, std::num::NonZeroUsize::get) + }); + WorkerPool { - num: std::thread::available_parallelism() - .map_or(2, std::num::NonZeroUsize::get), + num, no_signals: false, stop_runtime: false, shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT, + affinity: false, } } @@ -68,6 +75,14 @@ impl WorkerPool { self } + /// Enable core affinity + /// + /// By default affinity is disabled. + pub fn enable_affinity(mut self) -> Self { + self.affinity = true; + self + } + /// Starts processing incoming items and return server controller. pub fn run(self, factory: F) -> Server { crate::manager::ServerManager::start(self, factory) diff --git a/ntex-server/src/wrk.rs b/ntex-server/src/wrk.rs index 17ec56ae..b791817d 100644 --- a/ntex-server/src/wrk.rs +++ b/ntex-server/src/wrk.rs @@ -2,8 +2,9 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::task::{ready, Context, Poll}; use std::{cmp, future::poll_fn, future::Future, hash, pin::Pin, sync::Arc}; -use async_broadcast::{self as bus, broadcast}; use async_channel::{unbounded, Receiver, Sender}; +use atomic_waker::AtomicWaker; +use core_affinity::CoreId; use ntex_rt::{spawn, Arbiter}; use ntex_service::{Pipeline, PipelineBinding, Service, ServiceFactory}; @@ -77,7 +78,7 @@ pub struct WorkerStop(oneshot::Receiver); impl Worker { /// Start worker. - pub fn start(id: WorkerId, cfg: F) -> Worker + pub fn start(id: WorkerId, cfg: F, cid: Option) -> Worker where T: Send + 'static, F: ServerConfiguration, @@ -87,15 +88,21 @@ impl Worker { let (avail, avail_tx) = WorkerAvailability::create(); Arbiter::default().exec_fn(move || { + if let Some(cid) = cid { + if core_affinity::set_for_current(cid) { + log::info!("Set affinity to {:?} for worker {:?}", cid, id); + } + } + let _ = spawn(async move { log::info!("Starting worker {:?}", id); log::debug!("Creating server instance in {:?}", id); let factory = cfg.create().await; - log::debug!("Server instance has been created in {:?}", id); match create(id, rx1, rx2, factory, avail_tx).await { Ok((svc, wrk)) => { + log::debug!("Server instance has been created in {:?}", id); run_worker(svc, wrk).await; } Err(e) => { @@ -144,10 +151,8 @@ impl Worker { if self.failed.load(Ordering::Acquire) { WorkerStatus::Failed } else { - // cleanup updates - while self.avail.notify.try_recv().is_ok() {} - - if self.avail.notify.recv_direct().await.is_err() { + self.avail.wait_for_update().await; + if self.avail.failed() { self.failed.store(true, Ordering::Release); } self.status() @@ -189,52 +194,85 @@ impl Future for WorkerStop { #[derive(Debug, Clone)] struct WorkerAvailability { - notify: bus::Receiver<()>, - available: Arc, + inner: Arc, } #[derive(Debug, Clone)] struct WorkerAvailabilityTx { - notify: bus::Sender<()>, - available: Arc, + inner: Arc, +} + +#[derive(Debug)] +struct Inner { + waker: AtomicWaker, + updated: AtomicBool, + available: AtomicBool, + failed: AtomicBool, } impl WorkerAvailability { fn create() -> (Self, WorkerAvailabilityTx) { - let (mut tx, rx) = broadcast(16); - tx.set_overflow(true); + let inner = Arc::new(Inner { + waker: AtomicWaker::new(), + updated: AtomicBool::new(false), + available: AtomicBool::new(false), + failed: AtomicBool::new(false), + }); let avail = WorkerAvailability { - notify: rx, - available: Arc::new(AtomicBool::new(false)), - }; - let avail_tx = WorkerAvailabilityTx { - notify: tx, - available: avail.available.clone(), + inner: inner.clone(), }; + let avail_tx = WorkerAvailabilityTx { inner }; (avail, avail_tx) } + fn failed(&self) -> bool { + self.inner.failed.load(Ordering::Acquire) + } + fn available(&self) -> bool { - self.available.load(Ordering::Acquire) + self.inner.available.load(Ordering::Acquire) + } + + async fn wait_for_update(&self) { + poll_fn(|cx| { + if self.inner.updated.load(Ordering::Acquire) { + self.inner.updated.store(false, Ordering::Release); + Poll::Ready(()) + } else { + self.inner.waker.register(cx.waker()); + Poll::Pending + } + }) + .await; } } impl WorkerAvailabilityTx { fn set(&self, val: bool) { - let old = self.available.swap(val, Ordering::Release); - if !old && val { - let _ = self.notify.try_broadcast(()); + let old = self.inner.available.swap(val, Ordering::Release); + if old != val { + self.inner.updated.store(true, Ordering::Release); + self.inner.waker.wake(); } } } +impl Drop for WorkerAvailabilityTx { + fn drop(&mut self) { + self.inner.failed.store(true, Ordering::Release); + self.inner.updated.store(true, Ordering::Release); + self.inner.available.store(false, Ordering::Release); + self.inner.waker.wake(); + } +} + /// Service worker /// /// Worker accepts message via unbounded channel and starts processing. struct WorkerSt> { id: WorkerId, - rx: Pin>>, + rx: Receiver, stop: Pin>>, factory: F, availability: WorkerAvailabilityTx, @@ -246,25 +284,43 @@ where F: ServiceFactory + 'static, { loop { + let mut recv = std::pin::pin!(wrk.rx.recv()); let fut = poll_fn(|cx| { - ready!(svc.poll_ready(cx)?); - - if let Some(item) = ready!(Pin::new(&mut wrk.rx).poll_next(cx)) { - let fut = svc.call(item); - let _ = spawn(async move { - let _ = fut.await; - }); + match svc.poll_ready(cx) { + Poll::Ready(Ok(())) => { + wrk.availability.set(true); + } + Poll::Ready(Err(err)) => { + wrk.availability.set(false); + return Poll::Ready(Err(err)); + } + Poll::Pending => { + wrk.availability.set(false); + return Poll::Pending; + } + } + + match ready!(recv.as_mut().poll(cx)) { + Ok(item) => { + let fut = svc.call(item); + let _ = spawn(async move { + let _ = fut.await; + }); + Poll::Ready(Ok::<_, F::Error>(true)) + } + Err(_) => { + log::error!("Server is gone"); + Poll::Ready(Ok(false)) + } } - Poll::Ready(Ok::<(), F::Error>(())) }); match select(fut, stream_recv(&mut wrk.stop)).await { - Either::Left(Ok(())) => continue, + Either::Left(Ok(true)) => continue, Either::Left(Err(_)) => { let _ = ntex_rt::spawn(async move { svc.shutdown().await; }); - wrk.availability.set(false); } Either::Right(Some(Shutdown { timeout, result })) => { wrk.availability.set(false); @@ -278,7 +334,8 @@ where stop_svc(wrk.id, svc, timeout, Some(result)).await; return; } - Either::Right(None) => { + Either::Left(Ok(false)) | Either::Right(None) => { + wrk.availability.set(false); stop_svc(wrk.id, svc, STOP_TIMEOUT, None).await; return; } @@ -288,7 +345,6 @@ where loop { match select(wrk.factory.create(()), stream_recv(&mut wrk.stop)).await { Either::Left(Ok(service)) => { - wrk.availability.set(true); svc = Pipeline::new(service).bind(); break; } @@ -329,8 +385,6 @@ where { availability.set(false); let factory = factory?; - - let rx = Box::pin(rx); let mut stop = Box::pin(stop); let svc = match select(factory.create(()), stream_recv(&mut stop)).await { @@ -349,9 +403,9 @@ where svc, WorkerSt { id, + rx, factory, availability, - rx: Box::pin(rx), stop: Box::pin(stop), }, )) diff --git a/ntex-service/CHANGES.md b/ntex-service/CHANGES.md index 53d3b3cf..b6bc5503 100644 --- a/ntex-service/CHANGES.md +++ b/ntex-service/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [3.4.0] - 2024-12-04 + +* Added Service::poll() method + ## [3.3.3] - 2024-11-10 * Add Pipeline::is_shutdown() helper diff --git a/ntex-service/Cargo.toml b/ntex-service/Cargo.toml index f23fccd9..f338931d 100644 --- a/ntex-service/Cargo.toml +++ b/ntex-service/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-service" -version = "3.3.3" +version = "3.4.0" authors = ["ntex contributors "] description = "ntex service" keywords = ["network", "framework", "async", "futures"] diff --git a/ntex-service/src/and_then.rs b/ntex-service/src/and_then.rs index 494a8072..2e9e9f2f 100644 --- a/ntex-service/src/and_then.rs +++ b/ntex-service/src/and_then.rs @@ -31,8 +31,9 @@ where } #[inline] - async fn not_ready(&self) { - util::select(self.svc1.not_ready(), self.svc2.not_ready()).await + fn poll(&self, cx: &mut std::task::Context<'_>) -> Result<(), Self::Error> { + self.svc1.poll(cx)?; + self.svc2.poll(cx) } #[inline] @@ -88,8 +89,8 @@ where #[cfg(test)] mod tests { - use ntex_util::time; - use std::{cell::Cell, rc::Rc}; + use ntex_util::future::lazy; + use std::{cell::Cell, rc::Rc, task::Context}; use crate::{chain, chain_factory, fn_factory, Service, ServiceCtx}; @@ -105,9 +106,9 @@ mod tests { Ok(()) } - async fn not_ready(&self) { + fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> { self.0.set(self.0.get() + 1); - std::future::pending().await + Ok(()) } async fn call( @@ -135,9 +136,9 @@ mod tests { Ok(()) } - async fn not_ready(&self) { + fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> { self.0.set(self.0.get() + 1); - std::future::pending().await + Ok(()) } async fn call( @@ -165,11 +166,7 @@ mod tests { assert_eq!(res, Ok(())); assert_eq!(cnt.get(), 2); - let srv2 = srv.clone(); - ntex::rt::spawn(async move { - let _ = srv2.not_ready().await; - }); - time::sleep(time::Millis(25)).await; + lazy(|cx| srv.clone().poll(cx)).await.unwrap(); assert_eq!(cnt.get(), 4); srv.shutdown().await; diff --git a/ntex-service/src/apply.rs b/ntex-service/src/apply.rs index 43640e0c..84fc153a 100644 --- a/ntex-service/src/apply.rs +++ b/ntex-service/src/apply.rs @@ -113,7 +113,7 @@ where (self.f)(req, self.service.clone()).await } - crate::forward_notready!(service); + crate::forward_poll!(service); crate::forward_shutdown!(service); } @@ -205,7 +205,8 @@ where #[cfg(test)] mod tests { - use std::{cell::Cell, rc::Rc}; + use ntex_util::future::lazy; + use std::{cell::Cell, rc::Rc, task::Context}; use super::*; use crate::{chain, chain_factory, fn_factory}; @@ -221,8 +222,9 @@ mod tests { Ok(()) } - async fn not_ready(&self) { + fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> { self.0.set(self.0.get() + 1); + Ok(()) } async fn shutdown(&self) { @@ -253,7 +255,7 @@ mod tests { assert_eq!(srv.ready().await, Ok::<_, Err>(())); - srv.not_ready().await; + lazy(|cx| srv.poll(cx)).await.unwrap(); assert_eq!(cnt_sht.get(), 1); srv.shutdown().await; diff --git a/ntex-service/src/boxed.rs b/ntex-service/src/boxed.rs index 7e63fbea..6998ce61 100644 --- a/ntex-service/src/boxed.rs +++ b/ntex-service/src/boxed.rs @@ -1,4 +1,4 @@ -use std::{fmt, future::Future, pin::Pin}; +use std::{fmt, future::Future, pin::Pin, task::Context}; use crate::ctx::{ServiceCtx, WaitersRef}; @@ -54,8 +54,6 @@ trait ServiceObj { waiters: &'a WaitersRef, ) -> BoxFuture<'a, (), Self::Error>; - fn not_ready<'a>(&'a self) -> Pin + 'a>>; - fn call<'a>( &'a self, req: Req, @@ -64,6 +62,8 @@ trait ServiceObj { ) -> BoxFuture<'a, Self::Response, Self::Error>; fn shutdown<'a>(&'a self) -> Pin + 'a>>; + + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error>; } impl ServiceObj for S @@ -83,11 +83,6 @@ where Box::pin(async move { ServiceCtx::<'a, S>::new(idx, waiters).ready(self).await }) } - #[inline] - fn not_ready<'a>(&'a self) -> Pin + 'a>> { - Box::pin(crate::Service::not_ready(self)) - } - #[inline] fn shutdown<'a>(&'a self) -> Pin + 'a>> { Box::pin(crate::Service::shutdown(self)) @@ -106,6 +101,11 @@ where .await }) } + + #[inline] + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + crate::Service::poll(self, cx) + } } trait ServiceFactoryObj { @@ -158,11 +158,6 @@ where self.0.ready(idx, waiters).await } - #[inline] - async fn not_ready(&self) { - self.0.not_ready().await - } - #[inline] async fn shutdown(&self) { self.0.shutdown().await @@ -173,6 +168,11 @@ where let (idx, waiters) = ctx.inner(); self.0.call(req, idx, waiters).await } + + #[inline] + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + self.0.poll(cx) + } } impl crate::ServiceFactory diff --git a/ntex-service/src/chain.rs b/ntex-service/src/chain.rs index f75c9731..836d88fa 100644 --- a/ntex-service/src/chain.rs +++ b/ntex-service/src/chain.rs @@ -171,6 +171,7 @@ impl, Req> Service for ServiceChain { type Response = Svc::Response; type Error = Svc::Error; + crate::forward_poll!(service); crate::forward_ready!(service); crate::forward_shutdown!(service); diff --git a/ntex-service/src/ctx.rs b/ntex-service/src/ctx.rs index ed4de4ee..dbac0716 100644 --- a/ntex-service/src/ctx.rs +++ b/ntex-service/src/ctx.rs @@ -21,9 +21,6 @@ impl WaitersRef { pub(crate) fn new() -> (u32, Self) { let mut waiters = slab::Slab::new(); - // first insert for wake ups from services - let _ = waiters.insert(None); - ( waiters.insert(Default::default()) as u32, WaitersRef { @@ -68,18 +65,6 @@ impl WaitersRef { self.get()[idx as usize] = Some(cx.waker().clone()); } - pub(crate) fn register_unready(&self, cx: &mut Context<'_>) { - self.get()[0] = Some(cx.waker().clone()); - } - - pub(crate) fn notify_unready(&self) { - if let Some(item) = self.get().get_mut(0) { - if let Some(waker) = item.take() { - waker.wake(); - } - } - } - pub(crate) fn notify(&self) { let wakers = self.get_wakers(); if !wakers.is_empty() { diff --git a/ntex-service/src/lib.rs b/ntex-service/src/lib.rs index 67253013..df444562 100644 --- a/ntex-service/src/lib.rs +++ b/ntex-service/src/lib.rs @@ -6,7 +6,7 @@ unreachable_pub, missing_debug_implementations )] -use std::rc::Rc; +use std::{rc::Rc, task::Context}; mod and_then; mod apply; @@ -118,7 +118,8 @@ pub trait Service { Ok(()) } - #[inline] + #[deprecated] + #[doc(hidden)] /// Returns when the service is not able to process requests. /// /// Unlike the "ready()" method, the "not_ready()" method returns @@ -136,6 +137,15 @@ pub trait Service { /// Returns when the service is properly shutdowns. async fn shutdown(&self) {} + #[inline] + /// Polls service from the current task. + /// + /// Service may require to execute asynchronous computation or + /// maintain asynchronous state. + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + Ok(()) + } + #[inline] /// Map this service's output to a different type, returning a new service of the resulting type. /// @@ -246,7 +256,7 @@ pub trait ServiceFactory { } } -impl<'a, S, Req> Service for &'a S +impl Service for &S where S: Service, { @@ -259,8 +269,8 @@ where } #[inline] - async fn not_ready(&self) { - (**self).not_ready().await + fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> { + (**self).poll(cx) } #[inline] @@ -290,11 +300,6 @@ where ctx.ready(&**self).await } - #[inline] - async fn not_ready(&self) { - (**self).not_ready().await - } - #[inline] async fn shutdown(&self) { (**self).shutdown().await @@ -308,6 +313,11 @@ where ) -> Result { ctx.call_nowait(&**self, request).await } + + #[inline] + fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> { + (**self).poll(cx) + } } impl ServiceFactory for Rc diff --git a/ntex-service/src/macros.rs b/ntex-service/src/macros.rs index d951775d..846efa8d 100644 --- a/ntex-service/src/macros.rs +++ b/ntex-service/src/macros.rs @@ -11,11 +11,6 @@ macro_rules! forward_ready { .await .map_err(::core::convert::Into::into) } - - #[inline] - async fn not_ready(&self) { - self.$field.not_ready().await - } }; ($field:ident, $err:expr) => { #[inline] @@ -25,21 +20,28 @@ macro_rules! forward_ready { ) -> Result<(), Self::Error> { ctx.ready(&self.$field).await.map_err($err) } - - #[inline] - async fn not_ready(&self) { - self.$field.not_ready().await - } }; } /// An implementation of [`not_ready`] that forwards not_ready call to a field. #[macro_export] macro_rules! forward_notready { + ($field:ident) => {}; +} + +/// An implementation of [`poll`] that forwards poll call to a field. +#[macro_export] +macro_rules! forward_poll { ($field:ident) => { #[inline] - async fn not_ready(&self) { - self.$field.not_ready().await + fn poll(&self, cx: &mut std::task::Context<'_>) -> Result<(), Self::Error> { + self.$field.poll(cx).map_err(From::from) + } + }; + ($field:ident, $err:expr) => { + #[inline] + fn poll(&self, cx: &mut std::task::Context<'_>) -> Result<(), Self::Error> { + self.$field.poll(cx).map_err($err) } }; } diff --git a/ntex-service/src/map.rs b/ntex-service/src/map.rs index 7d4cb094..3f1e37ff 100644 --- a/ntex-service/src/map.rs +++ b/ntex-service/src/map.rs @@ -62,6 +62,7 @@ where type Error = A::Error; crate::forward_ready!(service); + crate::forward_poll!(service); crate::forward_shutdown!(service); #[inline] diff --git a/ntex-service/src/map_err.rs b/ntex-service/src/map_err.rs index 544b0f7e..97279c3d 100644 --- a/ntex-service/src/map_err.rs +++ b/ntex-service/src/map_err.rs @@ -1,4 +1,4 @@ -use std::{fmt, marker::PhantomData}; +use std::{fmt, marker::PhantomData, task::Context}; use super::{Service, ServiceCtx, ServiceFactory}; @@ -67,6 +67,11 @@ where ctx.ready(&self.service).await.map_err(&self.f) } + #[inline] + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + self.service.poll(cx).map_err(&self.f) + } + #[inline] async fn call( &self, @@ -77,7 +82,6 @@ where } crate::forward_shutdown!(service); - crate::forward_notready!(service); } /// Factory for the `map_err` combinator, changing the type of a new diff --git a/ntex-service/src/pipeline.rs b/ntex-service/src/pipeline.rs index 243ce885..8ff88942 100644 --- a/ntex-service/src/pipeline.rs +++ b/ntex-service/src/pipeline.rs @@ -1,4 +1,4 @@ -use std::{cell, fmt, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll}; +use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll}; use crate::{ctx::WaitersRef, Service, ServiceCtx}; @@ -50,13 +50,14 @@ impl Pipeline { .await } - #[inline] + #[doc(hidden)] + #[deprecated] /// Returns when the pipeline is not able to process requests. pub async fn not_ready(&self) where S: Service, { - self.state.svc.not_ready().await + std::future::pending().await } #[inline] @@ -125,6 +126,14 @@ impl Pipeline { self.state.svc.shutdown().await } + #[inline] + pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> + where + S: Service, + { + self.state.svc.poll(cx) + } + #[inline] /// Get current pipeline. pub fn bind(self) -> PipelineBinding @@ -175,7 +184,6 @@ where { pl: Pipeline, st: cell::UnsafeCell>, - not_ready: cell::UnsafeCell, } enum State { @@ -184,11 +192,6 @@ enum State { Shutdown(Pin + 'static>>), } -enum StateNotReady { - New, - Readiness(Pin>>), -} - impl PipelineBinding where S: Service + 'static, @@ -198,7 +201,6 @@ where PipelineBinding { pl, st: cell::UnsafeCell::new(State::New), - not_ready: cell::UnsafeCell::new(StateNotReady::New), } } @@ -214,6 +216,11 @@ where self.pl.clone() } + #[inline] + pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> { + self.pl.poll(cx) + } + #[inline] /// Returns `Ready` when the pipeline is able to process requests. /// @@ -230,6 +237,7 @@ where let fut = Box::pin(CheckReadiness { fut: None, f: ready, + _t: marker::PhantomData, pl, }); *st = State::Readiness(fut); @@ -240,27 +248,12 @@ where } } + #[doc(hidden)] + #[deprecated] #[inline] /// Returns when the pipeline is not able to process requests. - pub fn poll_not_ready(&self, cx: &mut Context<'_>) -> Poll<()> { - let st = unsafe { &mut *self.not_ready.get() }; - - match st { - StateNotReady::New => { - // SAFETY: `fut` has same lifetime same as lifetime of `self.pl`. - // Pipeline::svc is heap allocated(Rc), and it is being kept alive until - // `self` is alive - let pl: &'static Pipeline = unsafe { std::mem::transmute(&self.pl) }; - let fut = Box::pin(CheckUnReadiness { - fut: None, - f: not_ready, - pl, - }); - *st = StateNotReady::Readiness(fut); - self.poll_not_ready(cx) - } - StateNotReady::Readiness(ref mut fut) => Pin::new(fut).poll(cx), - } + pub fn poll_not_ready(&self, _: &mut Context<'_>) -> Poll<()> { + Poll::Pending } #[inline] @@ -276,7 +269,6 @@ where let pl: &'static Pipeline = unsafe { std::mem::transmute(&self.pl) }; *st = State::Shutdown(Box::pin(async move { pl.shutdown().await })); pl.state.waiters.shutdown(); - pl.state.waiters.notify_unready(); self.poll_shutdown(cx) } State::Shutdown(ref mut fut) => Pin::new(fut).poll(cx), @@ -345,7 +337,6 @@ where Self { pl: self.pl.clone(), st: cell::UnsafeCell::new(State::New), - not_ready: cell::UnsafeCell::new(StateNotReady::New), } } } @@ -404,23 +395,16 @@ where .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref())) } -fn not_ready(pl: &'static Pipeline) -> impl Future -where - S: Service, - R: 'static, -{ - pl.state.svc.not_ready() -} - -struct CheckReadiness { +struct CheckReadiness + 'static, R, F, Fut> { f: F, fut: Option, pl: &'static Pipeline, + _t: marker::PhantomData, } -impl Unpin for CheckReadiness {} +impl, R, F, Fut> Unpin for CheckReadiness {} -impl Drop for CheckReadiness { +impl, R, F, Fut> Drop for CheckReadiness { fn drop(&mut self) { // future fot dropped during polling, we must notify other waiters if self.fut.is_some() { @@ -429,16 +413,19 @@ impl Drop for CheckReadiness { } } -impl Future for CheckReadiness +impl Future for CheckReadiness where + S: Service, F: Fn(&'static Pipeline) -> Fut, - Fut: Future, + Fut: Future>, { - type Output = T; + type Output = Result<(), S::Error>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut slf = self.as_mut(); + slf.pl.poll(cx)?; + if slf.pl.state.waiters.can_check(slf.pl.index, cx) { if slf.fut.is_none() { slf.fut = Some((slf.f)(slf.pl)); @@ -460,43 +447,3 @@ where } } } - -struct CheckUnReadiness { - f: F, - fut: Option, - pl: &'static Pipeline, -} - -impl Unpin for CheckUnReadiness {} - -impl Future for CheckUnReadiness -where - F: Fn(&'static Pipeline) -> Fut, - Fut: Future, -{ - type Output = (); - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - let mut slf = self.as_mut(); - - if slf.fut.is_none() { - slf.fut = Some((slf.f)(slf.pl)); - } - let fut = slf.fut.as_mut().unwrap(); - match unsafe { Pin::new_unchecked(fut) }.poll(cx) { - Poll::Pending => { - if slf.pl.state.waiters.is_shutdown() { - Poll::Ready(()) - } else { - slf.pl.state.waiters.register_unready(cx); - Poll::Pending - } - } - Poll::Ready(()) => { - let _ = slf.fut.take(); - slf.pl.state.waiters.notify(); - Poll::Ready(()) - } - } - } -} diff --git a/ntex-service/src/then.rs b/ntex-service/src/then.rs index 2b733698..6de8a14e 100644 --- a/ntex-service/src/then.rs +++ b/ntex-service/src/then.rs @@ -31,8 +31,9 @@ where } #[inline] - async fn not_ready(&self) { - util::select(self.svc1.not_ready(), self.svc2.not_ready()).await + fn poll(&self, cx: &mut std::task::Context<'_>) -> Result<(), Self::Error> { + self.svc1.poll(cx)?; + self.svc2.poll(cx) } #[inline] @@ -91,8 +92,8 @@ where #[cfg(test)] mod tests { - use ntex_util::time; - use std::{cell::Cell, rc::Rc}; + use ntex_util::future::lazy; + use std::{cell::Cell, rc::Rc, task::Context}; use crate::{chain, chain_factory, fn_factory, Service, ServiceCtx}; @@ -108,9 +109,9 @@ mod tests { Ok(()) } - async fn not_ready(&self) { + fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> { self.0.set(self.0.get() + 1); - std::future::pending().await + Ok(()) } async fn call( @@ -141,9 +142,9 @@ mod tests { Ok(()) } - async fn not_ready(&self) { + fn poll(&self, _: &mut Context<'_>) -> Result<(), Self::Error> { self.0.set(self.0.get() + 1); - std::future::pending().await + Ok(()) } async fn call( @@ -173,11 +174,7 @@ mod tests { assert_eq!(res, Ok(())); assert_eq!(cnt.get(), 2); - let srv2 = srv.clone(); - ntex::rt::spawn(async move { - let _ = srv2.not_ready().await; - }); - time::sleep(time::Millis(25)).await; + lazy(|cx| srv.clone().poll(cx)).await.unwrap(); assert_eq!(cnt.get(), 4); srv.shutdown().await; diff --git a/ntex-service/src/util.rs b/ntex-service/src/util.rs index 4c041a15..5421b96c 100644 --- a/ntex-service/src/util.rs +++ b/ntex-service/src/util.rs @@ -59,24 +59,3 @@ where }) .await } - -/// Waits for either one of two differently-typed futures to complete. -pub(crate) async fn select(fut1: A, fut2: B) -> A::Output -where - A: Future, - B: Future, -{ - let mut fut1 = pin::pin!(fut1); - let mut fut2 = pin::pin!(fut2); - - poll_fn(|cx| { - if let Poll::Ready(item) = Pin::new(&mut fut1).poll(cx) { - return Poll::Ready(item); - } - if let Poll::Ready(item) = Pin::new(&mut fut2).poll(cx) { - return Poll::Ready(item); - } - Poll::Pending - }) - .await -} diff --git a/ntex-tls/CHANGES.md b/ntex-tls/CHANGES.md index 8e417bc4..94d5b547 100644 --- a/ntex-tls/CHANGES.md +++ b/ntex-tls/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.4.0] - 2024-12-30 + +* Enable rustls/std feature + ## [2.3.0] - 2024-11-04 * Use updated Service trait diff --git a/ntex-tls/Cargo.toml b/ntex-tls/Cargo.toml index 4a731964..701128c8 100644 --- a/ntex-tls/Cargo.toml +++ b/ntex-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tls" -version = "2.3.0" +version = "2.4.0" authors = ["ntex contributors "] description = "An implementation of SSL streams for ntex backed by OpenSSL" keywords = ["network", "framework", "async", "futures"] @@ -22,14 +22,14 @@ default = [] openssl = ["tls_openssl"] # rustls support -rustls = ["tls_rust"] +rustls = ["tls_rust", "tls_rust/std"] rustls-ring = ["tls_rust", "tls_rust/ring", "tls_rust/std"] [dependencies] ntex-bytes = "0.1" ntex-io = "2.3" ntex-util = "2.5" -ntex-service = "3.3" +ntex-service = "3.4" ntex-net = "2" log = "0.4" diff --git a/ntex-tls/examples/rustls-server.rs b/ntex-tls/examples/rustls-server.rs index 445cffec..a80b25e2 100644 --- a/ntex-tls/examples/rustls-server.rs +++ b/ntex-tls/examples/rustls-server.rs @@ -13,9 +13,8 @@ async fn main() -> io::Result<()> { println!("Started openssl echp server: 127.0.0.1:8443"); // load ssl keys - let cert_file = - &mut BufReader::new(File::open("../ntex-tls/examples/cert.pem").unwrap()); - let key_file = &mut BufReader::new(File::open("../ntex-tls/examples/key.pem").unwrap()); + let cert_file = &mut BufReader::new(File::open("../examples/cert.pem").unwrap()); + let key_file = &mut BufReader::new(File::open("../examples/key.pem").unwrap()); let keys = rustls_pemfile::private_key(key_file).unwrap().unwrap(); let cert_chain = rustls_pemfile::certs(cert_file) .collect::, _>>() diff --git a/ntex-tls/examples/webserver.rs b/ntex-tls/examples/webserver.rs index 52867a6b..9398708e 100644 --- a/ntex-tls/examples/webserver.rs +++ b/ntex-tls/examples/webserver.rs @@ -8,18 +8,18 @@ use tls_openssl::ssl::{self, SslFiletype, SslMethod}; #[ntex::main] async fn main() -> io::Result<()> { - //std::env::set_var("RUST_LOG", "trace"); - //env_logger::init(); + std::env::set_var("RUST_LOG", "trace"); + let _ = env_logger::try_init(); println!("Started openssl web server: 127.0.0.1:8443"); // load ssl keys let mut builder = ssl::SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); builder - .set_private_key_file("../tests/key.pem", SslFiletype::PEM) + .set_private_key_file("./examples/key.pem", SslFiletype::PEM) .unwrap(); builder - .set_certificate_chain_file("../tests/cert.pem") + .set_certificate_chain_file("./examples/cert.pem") .unwrap(); // h2 alpn config diff --git a/ntex-tls/src/openssl/connect.rs b/ntex-tls/src/openssl/connect.rs index 3df5debf..c2ffe528 100644 --- a/ntex-tls/src/openssl/connect.rs +++ b/ntex-tls/src/openssl/connect.rs @@ -51,11 +51,11 @@ impl SslConnector { log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host); match openssl.configure() { - Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()), + Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e).into()), Ok(config) => { let ssl = config .into_ssl(&host) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; let tag = io.tag(); match connect_io(io, ssl).await { Ok(io) => { @@ -64,7 +64,10 @@ impl SslConnector { } Err(e) => { log::trace!("{}: SSL Handshake error: {:?}", tag, e); - Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)).into()) + Err( + io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)) + .into(), + ) } } } diff --git a/ntex-tls/src/openssl/mod.rs b/ntex-tls/src/openssl/mod.rs index 45ed1fcd..429b1e41 100644 --- a/ntex-tls/src/openssl/mod.rs +++ b/ntex-tls/src/openssl/mod.rs @@ -250,7 +250,9 @@ async fn handle_result( ssl::ErrorCode::WANT_READ => { let res = io.read_notify().await; match res? { - None => Err(io::Error::new(io::ErrorKind::Other, "disconnected")), + None => { + Err(io::Error::new(io::ErrorKind::NotConnected, "disconnected")) + } _ => Ok(None), } } diff --git a/ntex-tls/src/rustls/mod.rs b/ntex-tls/src/rustls/mod.rs index 1d1ef685..70f48661 100644 --- a/ntex-tls/src/rustls/mod.rs +++ b/ntex-tls/src/rustls/mod.rs @@ -24,7 +24,7 @@ pub struct PeerCertChain<'a>(pub Vec>); pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>); -impl<'a, 'b> io::Read for Wrapper<'a, 'b> { +impl io::Read for Wrapper<'_, '_> { fn read(&mut self, dst: &mut [u8]) -> io::Result { self.0.with_read_buf(|buf| { buf.with_src(|buf| { @@ -41,7 +41,7 @@ impl<'a, 'b> io::Read for Wrapper<'a, 'b> { } } -impl<'a, 'b> io::Write for Wrapper<'a, 'b> { +impl io::Write for Wrapper<'_, '_> { fn write(&mut self, src: &[u8]) -> io::Result { self.0.with_dst(|buf| buf.extend_from_slice(src)); Ok(src.len()) diff --git a/ntex-util/CHANGES.md b/ntex-util/CHANGES.md index fbb386ff..d15ad9e2 100644 --- a/ntex-util/CHANGES.md +++ b/ntex-util/CHANGES.md @@ -1,5 +1,27 @@ # Changes +## [2.10.0] - 2025-03-12 + +* Add "Inplace" channel + +* Expose "yield_to" helper + +## [2.9.0] - 2025-01-15 + +* Add EitherService/EitherServiceFactory + +* Add retry middleware + +* Add future on drop handler + +## [2.8.0] - 2024-12-04 + +* Use updated Service trait + +## [2.7.0] - 2024-12-03 + +* Add time::Sleep::elapse() method + ## [2.6.1] - 2024-11-23 * Remove debug print diff --git a/ntex-util/Cargo.toml b/ntex-util/Cargo.toml index 0e5061f4..ef999259 100644 --- a/ntex-util/Cargo.toml +++ b/ntex-util/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-util" -version = "2.6.1" +version = "2.10.0" authors = ["ntex contributors "] description = "Utilities for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -16,7 +16,7 @@ name = "ntex_util" path = "src/lib.rs" [dependencies] -ntex-service = "3.3" +ntex-service = "3.4" ntex-rt = "0.4" bitflags = "2" fxhash = "0.2" diff --git a/ntex-util/src/channel/inplace.rs b/ntex-util/src/channel/inplace.rs new file mode 100644 index 00000000..88a119fe --- /dev/null +++ b/ntex-util/src/channel/inplace.rs @@ -0,0 +1,81 @@ +//! A futures-aware bounded(1) channel. +use std::{cell::Cell, fmt, future::poll_fn, task::Context, task::Poll}; + +use crate::task::LocalWaker; + +/// Creates a new futures-aware, channel. +pub fn channel() -> Inplace { + Inplace { + value: Cell::new(None), + rx_task: LocalWaker::new(), + } +} + +/// A futures-aware bounded(1) channel. +pub struct Inplace { + value: Cell>, + rx_task: LocalWaker, +} + +// The channels do not ever project Pin to the inner T +impl Unpin for Inplace {} + +impl fmt::Debug for Inplace { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Inplace") + } +} + +impl Inplace { + /// Set a successful result. + /// + /// If the value is successfully enqueued for the remote end to receive, + /// then `Ok(())` is returned. If previose value is not consumed + /// then `Err` is returned with the value provided. + pub fn send(&self, val: T) -> Result<(), T> { + if let Some(v) = self.value.take() { + self.value.set(Some(v)); + Err(val) + } else { + self.value.set(Some(val)); + self.rx_task.wake(); + Ok(()) + } + } + + /// Wait until the oneshot is ready and return value + pub async fn recv(&self) -> T { + poll_fn(|cx| self.poll_recv(cx)).await + } + + /// Polls the oneshot to determine if value is ready + pub fn poll_recv(&self, cx: &mut Context<'_>) -> Poll { + // If we've got a value, then skip the logic below as we're done. + if let Some(val) = self.value.take() { + return Poll::Ready(val); + } + + // Check if sender is dropped and return error if it is. + self.rx_task.register(cx.waker()); + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::future::lazy; + + #[ntex_macros::rt_test2] + async fn test_inplace() { + let ch = channel(); + assert_eq!(lazy(|cx| ch.poll_recv(cx)).await, Poll::Pending); + + assert!(ch.send(1).is_ok()); + assert!(ch.send(2) == Err(2)); + assert_eq!(lazy(|cx| ch.poll_recv(cx)).await, Poll::Ready(1)); + + assert!(ch.send(1).is_ok()); + assert_eq!(ch.recv().await, 1); + } +} diff --git a/ntex-util/src/channel/mod.rs b/ntex-util/src/channel/mod.rs index a8652c6b..06e5f2f1 100644 --- a/ntex-util/src/channel/mod.rs +++ b/ntex-util/src/channel/mod.rs @@ -2,6 +2,7 @@ mod cell; pub mod condition; +pub mod inplace; pub mod mpsc; pub mod oneshot; pub mod pool; diff --git a/ntex-util/src/future/mod.rs b/ntex-util/src/future/mod.rs index 675b31eb..d92637ae 100644 --- a/ntex-util/src/future/mod.rs +++ b/ntex-util/src/future/mod.rs @@ -7,12 +7,14 @@ pub use futures_sink::Sink; mod either; mod join; mod lazy; +mod on_drop; mod ready; mod select; pub use self::either::Either; pub use self::join::{join, join_all}; pub use self::lazy::{lazy, Lazy}; +pub use self::on_drop::{OnDropFn, OnDropFuture, OnDropFutureExt}; pub use self::ready::Ready; pub use self::select::select; diff --git a/ntex-util/src/future/on_drop.rs b/ntex-util/src/future/on_drop.rs new file mode 100644 index 00000000..62e8dc7a --- /dev/null +++ b/ntex-util/src/future/on_drop.rs @@ -0,0 +1,104 @@ +use std::{cell::Cell, fmt, future::Future, pin::Pin, task::Context, task::Poll}; + +/// Execute fn during drop +pub struct OnDropFn { + f: Cell>, +} + +impl OnDropFn { + pub fn new(f: F) -> Self { + Self { + f: Cell::new(Some(f)), + } + } + + /// Cancel fn execution + pub fn cancel(&self) { + self.f.take(); + } +} + +impl fmt::Debug for OnDropFn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OnDropFn") + .field("f", &std::any::type_name::()) + .finish() + } +} + +impl Drop for OnDropFn { + fn drop(&mut self) { + if let Some(f) = self.f.take() { + f() + } + } +} + +/// Trait adds future on_drop support +pub trait OnDropFutureExt: Future + Sized { + fn on_drop(self, on_drop: F) -> OnDropFuture { + OnDropFuture::new(self, on_drop) + } +} + +impl OnDropFutureExt for F {} + +pin_project_lite::pin_project! { + pub struct OnDropFuture { + #[pin] + fut: Ft, + on_drop: OnDropFn + } +} + +impl OnDropFuture { + pub fn new(fut: Ft, on_drop: F) -> Self { + Self { + fut, + on_drop: OnDropFn::new(on_drop), + } + } +} + +impl Future for OnDropFuture { + type Output = Ft::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.fut.poll(cx) { + Poll::Ready(r) => { + this.on_drop.cancel(); + Poll::Ready(r) + } + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod test { + use std::future::{pending, poll_fn}; + + use super::*; + + #[ntex_macros::rt_test2] + async fn on_drop() { + let f = OnDropFn::new(|| ()); + assert!(format!("{:?}", f).contains("OnDropFn")); + f.cancel(); + assert!(f.f.get().is_none()); + + let mut dropped = false; + let mut f = pending::<()>().on_drop(|| { + dropped = true; + }); + poll_fn(|cx| { + let _ = Pin::new(&mut f).poll(cx); + Poll::Ready(()) + }) + .await; + + drop(f); + assert!(dropped); + } +} diff --git a/ntex-util/src/services/buffer.rs b/ntex-util/src/services/buffer.rs index 9bf1c657..5dae648c 100644 --- a/ntex-util/src/services/buffer.rs +++ b/ntex-util/src/services/buffer.rs @@ -70,7 +70,6 @@ where fn create(&self, service: S) -> Self::Service { BufferService { service: Pipeline::new(service).bind(), - service_pending: Cell::new(true), size: self.buf_size, ready: Cell::new(false), buf: RefCell::new(VecDeque::with_capacity(self.buf_size)), @@ -113,7 +112,6 @@ impl std::error::Error for BufferService pub struct BufferService> { size: usize, ready: Cell, - service_pending: Cell, service: PipelineBinding, buf: RefCell>>>, next_call: RefCell>>, @@ -131,7 +129,6 @@ where Self { size, service: Pipeline::new(service).bind(), - service_pending: Cell::new(true), ready: Cell::new(false), buf: RefCell::new(VecDeque::with_capacity(size)), next_call: RefCell::default(), @@ -158,7 +155,6 @@ where size: self.size, ready: Cell::new(false), service: self.service.clone(), - service_pending: Cell::new(false), buf: RefCell::new(VecDeque::with_capacity(self.size)), next_call: RefCell::default(), cancel_on_shutdown: self.cancel_on_shutdown, @@ -178,7 +174,6 @@ where .field("cancel_on_shutdown", &self.cancel_on_shutdown) .field("ready", &self.ready) .field("service", &self.service) - .field("service_pending", &self.service_pending) .field("buf", &self.buf) .field("next_call", &self.next_call) .finish() @@ -208,18 +203,14 @@ where if buffer.len() < self.size { // buffer next request self.ready.set(false); - self.service_pending.set(false); Poll::Ready(Ok(())) } else { log::trace!("Buffer limit exceeded"); // service is not ready - self.service_pending.set(true); let _ = self.readiness.take().map(|w| w.wake()); Poll::Pending } } else { - self.service_pending.set(false); - while let Some(sender) = buffer.pop_front() { let (next_call_tx, next_call_rx) = oneshot::channel(); if sender.send(next_call_tx).is_err() @@ -240,19 +231,6 @@ where .await } - async fn not_ready(&self) { - let fut = poll_fn(|cx| { - if self.service_pending.get() { - Poll::Ready(()) - } else { - self.readiness.set(Some(cx.waker().clone())); - Poll::Pending - } - }); - - crate::future::select(fut, self.service.get_ref().not_ready()).await; - } - async fn shutdown(&self) { // hold advancement until the last released task either makes a call or is dropped let next_call = self.next_call.borrow_mut().take(); @@ -318,6 +296,8 @@ where Ok(self.service.call(req).await?) } } + + ntex_service::forward_poll!(service); } #[cfg(test)] @@ -373,7 +353,6 @@ mod tests { let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone())).clone()).bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); let srv1 = srv.clone(); ntex::rt::spawn(async move { @@ -382,7 +361,6 @@ mod tests { crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(inner.count.get(), 0); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); let srv1 = srv.clone(); ntex::rt::spawn(async move { @@ -391,12 +369,10 @@ mod tests { crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(inner.count.get(), 0); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Ready(())); inner.ready.set(true); inner.waker.wake(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(inner.count.get(), 1); @@ -404,7 +380,6 @@ mod tests { inner.ready.set(true); inner.waker.wake(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(inner.count.get(), 2); @@ -417,12 +392,10 @@ mod tests { let srv = Pipeline::new(BufferService::new(2, TestService(inner.clone()))).bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); let _ = srv.call(()).await; assert_eq!(inner.count.get(), 1); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); assert!(lazy(|cx| srv.poll_shutdown(cx)).await.is_ready()); let err = BufferServiceError::from("test"); diff --git a/ntex-util/src/services/either.rs b/ntex-util/src/services/either.rs new file mode 100644 index 00000000..a17d7672 --- /dev/null +++ b/ntex-util/src/services/either.rs @@ -0,0 +1,239 @@ +//! Either service allows to use different services for handling request +use std::{fmt, task::Context}; + +use ntex_service::{Service, ServiceCtx, ServiceFactory}; + +use crate::future::Either; + +#[derive(Clone)] +/// Either service +/// +/// Either service allows to use different services for handling requests +pub struct EitherService { + svc: Either, +} + +#[derive(Clone)] +/// Either service factory +/// +/// Either service allows to use different services for handling requests +pub struct EitherServiceFactory { + left: SFLeft, + right: SFRight, + choose_left_fn: ChooseFn, +} + +impl EitherServiceFactory { + /// Create `Either` service factory + pub fn new(choose_left_fn: ChooseFn, sf_left: SFLeft, sf_right: SFRight) -> Self { + EitherServiceFactory { + choose_left_fn, + left: sf_left, + right: sf_right, + } + } +} + +impl fmt::Debug + for EitherServiceFactory +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EitherServiceFactory") + .field("left", &std::any::type_name::()) + .field("right", &std::any::type_name::()) + .field("choose_fn", &std::any::type_name::()) + .finish() + } +} + +impl ServiceFactory + for EitherServiceFactory +where + ChooseFn: Fn(&C) -> bool, + SFLeft: ServiceFactory, + SFRight: ServiceFactory< + R, + C, + Response = SFLeft::Response, + InitError = SFLeft::InitError, + Error = SFLeft::Error, + >, +{ + type Response = SFLeft::Response; + type Error = SFLeft::Error; + type InitError = SFLeft::InitError; + type Service = EitherService; + + async fn create(&self, cfg: C) -> Result { + let choose_left = (self.choose_left_fn)(&cfg); + + if choose_left { + let svc = self.left.create(cfg).await?; + Ok(EitherService { + svc: Either::Left(svc), + }) + } else { + let svc = self.right.create(cfg).await?; + Ok(EitherService { + svc: Either::Right(svc), + }) + } + } +} + +impl EitherService { + /// Create `Either` service + pub fn left(svc: SLeft) -> Self { + EitherService { + svc: Either::Left(svc), + } + } + + /// Create `Either` service + pub fn right(svc: SRight) -> Self { + EitherService { + svc: Either::Right(svc), + } + } +} + +impl fmt::Debug for EitherService { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("EitherService") + .field("left", &std::any::type_name::()) + .field("right", &std::any::type_name::()) + .finish() + } +} + +impl Service for EitherService +where + SLeft: Service, + SRight: Service, +{ + type Response = SLeft::Response; + type Error = SLeft::Error; + + #[inline] + async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { + match self.svc { + Either::Left(ref svc) => ctx.ready(svc).await, + Either::Right(ref svc) => ctx.ready(svc).await, + } + } + + #[inline] + async fn shutdown(&self) { + match self.svc { + Either::Left(ref svc) => svc.shutdown().await, + Either::Right(ref svc) => svc.shutdown().await, + } + } + + #[inline] + async fn call( + &self, + req: Req, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + match self.svc { + Either::Left(ref svc) => ctx.call(svc, req).await, + Either::Right(ref svc) => ctx.call(svc, req).await, + } + } + + #[inline] + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + match self.svc { + Either::Left(ref svc) => svc.poll(cx), + Either::Right(ref svc) => svc.poll(cx), + } + } +} + +#[cfg(test)] +mod tests { + use ntex_service::{Pipeline, ServiceFactory}; + + use super::*; + + #[derive(Copy, Clone, Debug, PartialEq)] + struct Svc1; + impl Service<()> for Svc1 { + type Response = &'static str; + type Error = (); + + async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<&'static str, ()> { + Ok("svc1") + } + } + + #[derive(Clone)] + struct Svc1Factory; + impl ServiceFactory<(), &'static str> for Svc1Factory { + type Response = &'static str; + type Error = (); + type InitError = (); + type Service = Svc1; + + async fn create(&self, _: &'static str) -> Result { + Ok(Svc1) + } + } + + #[derive(Copy, Clone, Debug, PartialEq)] + struct Svc2; + impl Service<()> for Svc2 { + type Response = &'static str; + type Error = (); + + async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<&'static str, ()> { + Ok("svc2") + } + } + + #[derive(Clone)] + struct Svc2Factory; + impl ServiceFactory<(), &'static str> for Svc2Factory { + type Response = &'static str; + type Error = (); + type InitError = (); + type Service = Svc2; + + async fn create(&self, _: &'static str) -> Result { + Ok(Svc2) + } + } + + type Either = EitherService; + type EitherFactory = EitherServiceFactory; + + #[ntex_macros::rt_test2] + async fn test_success() { + let svc = Pipeline::new(Either::left(Svc1).clone()); + assert_eq!(svc.call(()).await, Ok("svc1")); + assert_eq!(svc.ready().await, Ok(())); + svc.shutdown().await; + + let svc = Pipeline::new(Either::right(Svc2).clone()); + assert_eq!(svc.call(()).await, Ok("svc2")); + assert_eq!(svc.ready().await, Ok(())); + svc.shutdown().await; + + assert!(format!("{:?}", svc).contains("EitherService")); + } + + #[ntex_macros::rt_test2] + async fn test_factory() { + let factory = + EitherFactory::new(|s: &&'static str| *s == "svc1", Svc1Factory, Svc2Factory) + .clone(); + assert!(format!("{:?}", factory).contains("EitherServiceFactory")); + + let svc = factory.pipeline("svc1").await.unwrap(); + assert_eq!(svc.call(()).await, Ok("svc1")); + + let svc = factory.pipeline("other").await.unwrap(); + assert_eq!(svc.call(()).await, Ok("svc2")); + } +} diff --git a/ntex-util/src/services/inflight.rs b/ntex-util/src/services/inflight.rs index 692a1be7..a02a4f88 100644 --- a/ntex-util/src/services/inflight.rs +++ b/ntex-util/src/services/inflight.rs @@ -71,13 +71,6 @@ where } } - #[inline] - async fn not_ready(&self) { - if self.count.is_available() { - crate::future::select(self.count.unavailable(), self.service.not_ready()).await; - } - } - #[inline] async fn call( &self, @@ -88,6 +81,7 @@ where ctx.call(&self.service, req).await } + ntex_service::forward_poll!(service); ntex_service::forward_shutdown!(service); } @@ -118,7 +112,6 @@ mod tests { let srv = Pipeline::new(InFlightService::new(1, SleepService(rx))).bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); let srv2 = srv.clone(); ntex::rt::spawn(async move { @@ -126,12 +119,10 @@ mod tests { }); crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Ready(())); let _ = tx.send(()); crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); srv.shutdown().await; } diff --git a/ntex-util/src/services/keepalive.rs b/ntex-util/src/services/keepalive.rs index c1304b21..2b9b4ad5 100644 --- a/ntex-util/src/services/keepalive.rs +++ b/ntex-util/src/services/keepalive.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, convert::Infallible, fmt, marker, time}; +use std::{cell::Cell, convert::Infallible, fmt, marker, task::Context, task::Poll, time}; use ntex_service::{Service, ServiceCtx, ServiceFactory}; @@ -119,22 +119,26 @@ where } } - async fn not_ready(&self) { - loop { - self.sleep.wait().await; - - let now = now(); - let expire = self.expire.get() + time::Duration::from(self.dur); - if expire <= now { - return; - } else { - let expire = expire - now; - self.sleep - .reset(Millis(expire.as_millis().try_into().unwrap_or(u32::MAX))); + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + match self.sleep.poll_elapsed(cx) { + Poll::Ready(_) => { + let now = now(); + let expire = self.expire.get() + time::Duration::from(self.dur); + if expire <= now { + Err((self.f)()) + } else { + let expire = expire - now; + self.sleep + .reset(Millis(expire.as_millis().try_into().unwrap_or(u32::MAX))); + let _ = self.sleep.poll_elapsed(cx); + Ok(()) + } } + Poll::Pending => Ok(()), } } + #[inline] async fn call(&self, req: R, _: ServiceCtx<'_, Self>) -> Result { self.expire.set(now()); Ok(req) @@ -162,13 +166,11 @@ mod tests { assert_eq!(service.call(1usize).await, Ok(1usize)); assert!(lazy(|cx| service.poll_ready(cx)).await.is_ready()); - assert!(!lazy(|cx| service.poll_not_ready(cx)).await.is_ready()); sleep(Millis(500)).await; assert_eq!( lazy(|cx| service.poll_ready(cx)).await, Poll::Ready(Err(TestErr)) ); - assert!(lazy(|cx| service.poll_not_ready(cx)).await.is_ready()); } } diff --git a/ntex-util/src/services/mod.rs b/ntex-util/src/services/mod.rs index 6f60afb0..5b4104f7 100644 --- a/ntex-util/src/services/mod.rs +++ b/ntex-util/src/services/mod.rs @@ -1,8 +1,10 @@ pub mod buffer; +pub mod either; mod extensions; pub mod inflight; pub mod keepalive; pub mod onerequest; +pub mod retry; pub mod timeout; pub mod variant; diff --git a/ntex-util/src/services/onerequest.rs b/ntex-util/src/services/onerequest.rs index 5e9e43c4..ab88eb51 100644 --- a/ntex-util/src/services/onerequest.rs +++ b/ntex-util/src/services/onerequest.rs @@ -65,24 +65,6 @@ where ctx.ready(&self.service).await } - #[inline] - async fn not_ready(&self) { - if self.ready.get() { - crate::future::select( - poll_fn(|cx| { - self.waker.register(cx.waker()); - if self.ready.get() { - Poll::Pending - } else { - Poll::Ready(()) - } - }), - self.service.not_ready(), - ) - .await; - } - } - #[inline] async fn call( &self, @@ -90,7 +72,6 @@ where ctx: ServiceCtx<'_, Self>, ) -> Result { self.ready.set(false); - self.waker.wake(); let result = ctx.call(&self.service, req).await; self.ready.set(true); @@ -98,6 +79,7 @@ where result } + ntex_service::forward_poll!(service); ntex_service::forward_shutdown!(service); } @@ -127,7 +109,6 @@ mod tests { let srv = Pipeline::new(OneRequestService::new(SleepService(rx))).bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); let srv2 = srv.clone(); ntex::rt::spawn(async move { @@ -135,12 +116,10 @@ mod tests { }); crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Pending); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Ready(())); let _ = tx.send(()); crate::time::sleep(Duration::from_millis(25)).await; assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); - assert_eq!(lazy(|cx| srv.poll_not_ready(cx)).await, Poll::Pending); srv.shutdown().await; } diff --git a/ntex-util/src/services/retry.rs b/ntex-util/src/services/retry.rs new file mode 100644 index 00000000..fe5889d6 --- /dev/null +++ b/ntex-util/src/services/retry.rs @@ -0,0 +1,177 @@ +#![allow(async_fn_in_trait)] +use ntex_service::{Middleware, Service, ServiceCtx}; + +/// Trait defines retry policy +pub trait Policy>: Sized + Clone { + async fn retry(&mut self, req: &Req, res: &Result) -> bool; + + fn clone_request(&self, req: &Req) -> Option; +} + +#[derive(Clone, Debug)] +/// Retry middleware +/// +/// Retry middleware allows to retry service call +pub struct Retry

{ + policy: P, +} + +#[derive(Clone, Debug)] +/// Retry service +/// +/// Retry service allows to retry service call +pub struct RetryService { + policy: P, + service: S, +} + +impl

Retry

{ + /// Create retry middleware + pub fn new(policy: P) -> Self { + Retry { policy } + } +} + +impl Middleware for Retry

{ + type Service = RetryService; + + fn create(&self, service: S) -> Self::Service { + RetryService { + service, + policy: self.policy.clone(), + } + } +} + +impl RetryService { + /// Create retry service + pub fn new(policy: P, service: S) -> Self { + RetryService { policy, service } + } +} + +impl Service for RetryService +where + P: Policy, + S: Service, +{ + type Response = S::Response; + type Error = S::Error; + + ntex_service::forward_poll!(service); + ntex_service::forward_ready!(service); + ntex_service::forward_shutdown!(service); + + async fn call( + &self, + mut request: R, + ctx: ServiceCtx<'_, Self>, + ) -> Result { + let mut policy = self.policy.clone(); + let mut cloned = policy.clone_request(&request); + + loop { + let result = ctx.call(&self.service, request).await; + + cloned = if let Some(req) = cloned.take() { + if policy.retry(&req, &result).await { + request = req; + policy.clone_request(&request) + } else { + return result; + } + } else { + return result; + } + } + } +} + +#[derive(Copy, Clone, Debug)] +/// Default retry policy +/// +/// This policy retries on any error. By default retry count is 3 +pub struct DefaultRetryPolicy(u16); + +impl DefaultRetryPolicy { + /// Create default retry policy + pub fn new(retry: u16) -> Self { + DefaultRetryPolicy(retry) + } +} + +impl Default for DefaultRetryPolicy { + fn default() -> Self { + DefaultRetryPolicy::new(3) + } +} + +impl Policy for DefaultRetryPolicy +where + R: Clone, + S: Service, +{ + async fn retry(&mut self, _: &R, res: &Result) -> bool { + if res.is_err() { + if self.0 == 0 { + false + } else { + self.0 -= 1; + true + } + } else { + false + } + } + + fn clone_request(&self, req: &R) -> Option { + Some(req.clone()) + } +} + +#[cfg(test)] +mod tests { + use std::{cell::Cell, rc::Rc}; + + use ntex_service::{apply, fn_factory, Pipeline, ServiceFactory}; + + use super::*; + + #[derive(Clone, Debug, PartialEq)] + struct TestService(Rc>); + + impl Service<()> for TestService { + type Response = (); + type Error = (); + + async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> { + let cnt = self.0.get(); + if cnt == 0 { + Ok(()) + } else { + self.0.set(cnt - 1); + Err(()) + } + } + } + + #[ntex_macros::rt_test2] + async fn test_retry() { + let cnt = Rc::new(Cell::new(5)); + let svc = Pipeline::new( + RetryService::new(DefaultRetryPolicy::default(), TestService(cnt.clone())) + .clone(), + ); + assert_eq!(svc.call(()).await, Err(())); + assert_eq!(svc.ready().await, Ok(())); + svc.shutdown().await; + assert_eq!(cnt.get(), 1); + + let factory = apply( + Retry::new(DefaultRetryPolicy::new(3)).clone(), + fn_factory(|| async { Ok::<_, ()>(TestService(Rc::new(Cell::new(2)))) }), + ); + let srv = factory.pipeline(&()).await.unwrap(); + assert_eq!(srv.call(()).await, Ok(())); + } +} diff --git a/ntex-util/src/services/timeout.rs b/ntex-util/src/services/timeout.rs index 5c6aac4f..39905d15 100644 --- a/ntex-util/src/services/timeout.rs +++ b/ntex-util/src/services/timeout.rs @@ -140,6 +140,7 @@ where } } + ntex_service::forward_poll!(service, TimeoutError::Service); ntex_service::forward_ready!(service, TimeoutError::Service); ntex_service::forward_shutdown!(service); } @@ -209,7 +210,7 @@ mod tests { #[ntex_macros::rt_test2] #[allow(clippy::redundant_clone)] - async fn test_timeout_newservice() { + async fn test_timeout_middleware() { let resolution = Duration::from_millis(100); let wait_time = Duration::from_millis(500); diff --git a/ntex-util/src/services/variant.rs b/ntex-util/src/services/variant.rs index 7d498c3b..08b71503 100644 --- a/ntex-util/src/services/variant.rs +++ b/ntex-util/src/services/variant.rs @@ -143,23 +143,10 @@ macro_rules! variant_impl ({$mod_name:ident, $enum_type:ident, $srv_type:ident, }).await } - async fn not_ready(&self) { - use std::{future::Future, pin::Pin}; - - let mut fut1 = ::std::pin::pin!(self.V1.not_ready()); - $(let mut $T = ::std::pin::pin!(self.$T.not_ready());)+ - - ::std::future::poll_fn(|cx| { - if Pin::new(&mut fut1).poll(cx).is_ready() { - return Poll::Ready(()) - } - - $(if Pin::new(&mut $T).poll(cx).is_ready() { - return Poll::Ready(()); - })+ - - Poll::Pending - }).await + fn poll(&self, cx: &mut std::task::Context<'_>) -> Result<(), Self::Error> { + self.V1.poll(cx)?; + $(self.$T.poll(cx)?;)+ + Ok(()) } async fn shutdown(&self) { @@ -253,7 +240,6 @@ variant_impl_and!(VariantFactory7, VariantFactory8, V8, V8R, v8, (V2, V3, V4, V5 #[cfg(test)] mod tests { use ntex_service::fn_factory; - use std::{future::poll_fn, future::Future, pin}; use super::*; @@ -307,16 +293,7 @@ mod tests { let service = factory.pipeline(&()).await.unwrap().clone(); assert!(format!("{:?}", service).contains("Variant")); - let mut f = pin::pin!(service.not_ready()); - let _ = poll_fn(|cx| { - if pin::Pin::new(&mut f).poll(cx).is_pending() { - Poll::Ready(()) - } else { - Poll::Pending - } - }) - .await; - + assert!(crate::future::lazy(|cx| service.poll(cx)).await.is_ok()); assert!(service.ready().await.is_ok()); service.shutdown().await; diff --git a/ntex-util/src/task.rs b/ntex-util/src/task.rs index a2c427ce..466715a6 100644 --- a/ntex-util/src/task.rs +++ b/ntex-util/src/task.rs @@ -91,7 +91,6 @@ impl fmt::Debug for LocalWaker { } } -#[doc(hidden)] /// Yields execution back to the current runtime. pub async fn yield_to() { use std::{future::Future, pin::Pin, task::Context, task::Poll}; diff --git a/ntex-util/src/time/mod.rs b/ntex-util/src/time/mod.rs index 3d181b6c..8b97f7a4 100644 --- a/ntex-util/src/time/mod.rs +++ b/ntex-util/src/time/mod.rs @@ -101,6 +101,12 @@ impl Sleep { self.hnd.is_elapsed() } + /// Complete sleep timer. + #[inline] + pub fn elapse(&self) { + self.hnd.elapse() + } + /// Resets the `Sleep` instance to a new deadline. /// /// Calling this function allows changing the instant at which the `Sleep` @@ -354,7 +360,7 @@ impl crate::Stream for Interval { #[allow(clippy::let_underscore_future)] mod tests { use futures_util::StreamExt; - use std::time; + use std::{future::poll_fn, rc::Rc, time}; use super::*; use crate::future::lazy; @@ -449,6 +455,17 @@ mod tests { fut.await; let second_time = now(); assert!(second_time - first_time < time::Duration::from_millis(1)); + + let first_time = now(); + let fut = Rc::new(sleep(Millis(100000))); + let s = fut.clone(); + ntex::rt::spawn(async move { + s.elapse(); + }); + poll_fn(|cx| fut.poll_elapsed(cx)).await; + assert!(fut.is_elapsed()); + let second_time = now(); + assert!(second_time - first_time < time::Duration::from_millis(1)); } #[ntex_macros::rt_test2] diff --git a/ntex-util/src/time/wheel.rs b/ntex-util/src/time/wheel.rs index 3aa0bdbd..16421772 100644 --- a/ntex-util/src/time/wheel.rs +++ b/ntex-util/src/time/wheel.rs @@ -106,6 +106,11 @@ impl TimerHandle { TIMER.with(|t| t.update_timer(self.0, millis)) } + /// Resets the `TimerHandle` instance to elapsed state. + pub fn elapse(&self) { + TIMER.with(|t| t.remove_timer(self.0)) + } + pub fn is_elapsed(&self) -> bool { TIMER.with(|t| t.with_mod(|m| m.timers[self.0].bucket.is_none())) } @@ -303,6 +308,14 @@ impl Timer { }) } + /// Remove timer and wake task + fn remove_timer(&self, hnd: usize) { + self.with_mod(|inner| { + inner.remove_timer_bucket(hnd, false); + inner.timers[hnd].complete(); + }) + } + /// Update existing timer fn update_timer(&self, hnd: usize, millis: u64) { self.with_mod(|inner| { @@ -345,10 +358,6 @@ impl Timer { } }) } - - // fn remove_timer(&self, handle: usize) { - // self.0.inner.borrow_mut().remove_timer_bucket(handle, true) - // } } impl TimerMod { diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index a1f89ac7..6ef4b5ef 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,51 @@ # Changes +## [2.12.4] - 2025-03-28 + +* http: Return PayloadError::Incomplete on server disconnect + +* web: Expose WebStack for external wrapper support in downstream crates #542 + +## [2.12.3] - 2025-03-22 + +* web: Export web::app_service::AppService #534 + +* http: Add delay for test server availability, could cause connect race + +## [2.12.2] - 2025-03-15 + +* http: Allow to run publish future to completion in case error + +* http: Remove brotli support + +## [2.12.1] - 2025-03-14 + +* Allow to disable test logging (no-test-logging features) + +## [2.12.0] - 2025-03-12 + +* Add neon runtime support + +* Check test server availability before using it + +* Drop glommio support + +* Drop async-std support + +## [2.11.0] - 2025-01-31 + +* Cpu affinity support for server + +## [2.10.0] - 2024-12-04 + +* Use updated Service trait + +## [2.9.0] - 2024-11-30 + +* Fix handling unconsumed payload in h1 dispatcher #477 + +* Move body to ntex-http + ## [2.8.0] - 2024-11-04 * Use updated Service trait diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index a2490d51..0ea37469 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "2.8.0" +version = "2.12.4" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -18,7 +18,7 @@ edition = "2021" rust-version = "1.75" [package.metadata.docs.rs] -features = ["tokio", "openssl", "rustls", "compress", "cookie", "ws", "brotli", "ntex-tls/rustls-ring"] +features = ["tokio", "openssl", "rustls", "compress", "cookie", "ws", "ntex-tls/rustls-ring"] [lib] name = "ntex" @@ -45,34 +45,34 @@ url = ["url-pkg"] # tokio runtime tokio = ["ntex-net/tokio"] -# glommio runtime -glommio = ["ntex-net/glommio"] - -# async-std runtime -async-std = ["ntex-net/async-std"] - # compio runtime compio = ["ntex-net/compio"] +# neon runtime +neon = ["ntex-net/neon"] + +# neon runtime +neon-uring = ["ntex-net/neon", "ntex-net/io-uring"] + # websocket support ws = ["dep:sha-1"] -# brotli2 support -brotli = ["dep:brotli2"] +# disable [ntex::test] logging configuration +no-test-logging = [] [dependencies] ntex-codec = "0.6" -ntex-http = "0.1.12" +ntex-http = "0.1.13" ntex-router = "0.5" -ntex-service = "3.3" +ntex-service = "3.4" ntex-macros = "0.1" -ntex-util = "2.5" +ntex-util = "2.8" ntex-bytes = "0.1.27" -ntex-server = "2.5" -ntex-h2 = "1.4" -ntex-rt = "0.4.19" -ntex-io = "2.8" -ntex-net = "2.4" +ntex-server = "2.7.4" +ntex-h2 = "1.8.6" +ntex-rt = "0.4.27" +ntex-io = "2.11" +ntex-net = "2.5.10" ntex-tls = "2.3" base64 = "0.22" @@ -83,6 +83,7 @@ pin-project-lite = "0.2" regex = { version = "1.11", default-features = false, features = ["std"] } serde = { version = "1", features = ["derive"] } sha-1 = { version = "0.10", optional = true } +env_logger = { version = "0.11", default-features = false } thiserror = "1" nanorand = { version = "0.7", default-features = false, features = [ "std", @@ -108,13 +109,12 @@ tls-rustls = { version = "0.23", package = "rustls", optional = true, default-fe webpki-roots = { version = "0.26", optional = true } # compression -brotli2 = { version = "0.3.2", optional = true } flate2 = { version = "1.0", optional = true } [dev-dependencies] -env_logger = "0.11" rand = "0.8" time = "0.3" +oneshot = "0.1" futures-util = "0.3" tls-openssl = { version = "0.10", package = "openssl" } tls-rustls = { version = "0.23", package = "rustls", features = ["ring", "std"], default-features = false } diff --git a/ntex/src/http/client/connector.rs b/ntex/src/http/client/connector.rs index f0982dc9..122f2647 100644 --- a/ntex/src/http/client/connector.rs +++ b/ntex/src/http/client/connector.rs @@ -1,11 +1,11 @@ -use std::{fmt, time::Duration}; +use std::{fmt, task::Context, time::Duration}; use ntex_h2::{self as h2}; use crate::connect::{Connect as TcpConnect, Connector as TcpConnector}; use crate::service::{apply_fn, boxed, Service, ServiceCtx}; use crate::time::{Millis, Seconds}; -use crate::util::{join, select, timeout::TimeoutError, timeout::TimeoutService}; +use crate::util::{join, timeout::TimeoutError, timeout::TimeoutService}; use crate::{http::Uri, io::IoBoxed}; use super::{connection::Connection, error::ConnectError, pool::ConnectionPool, Connect}; @@ -285,12 +285,12 @@ where } #[inline] - async fn not_ready(&self) { + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + self.tcp_pool.poll(cx)?; if let Some(ref ssl_pool) = self.ssl_pool { - select(self.tcp_pool.not_ready(), ssl_pool.not_ready()).await; - } else { - self.tcp_pool.not_ready().await + ssl_pool.poll(cx)?; } + Ok(()) } async fn shutdown(&self) { diff --git a/ntex/src/http/client/h1proto.rs b/ntex/src/http/client/h1proto.rs index 06572418..28871225 100644 --- a/ntex/src/http/client/h1proto.rs +++ b/ntex/src/http/client/h1proto.rs @@ -1,13 +1,11 @@ -use std::{ - future::poll_fn, io, io::Write, pin::Pin, task::Context, task::Poll, time::Instant, -}; +use std::{future::poll_fn, io, io::Write, pin::Pin, task, task::Poll, time::Instant}; use crate::http::body::{BodySize, MessageBody}; use crate::http::error::PayloadError; -use crate::http::h1; use crate::http::header::{HeaderMap, HeaderValue, HOST}; use crate::http::message::{RequestHeadType, ResponseHead}; use crate::http::payload::{Payload, PayloadStream}; +use crate::http::{h1, Version}; use crate::io::{IoBoxed, RecvError}; use crate::time::{timeout_checked, Millis}; use crate::util::{ready, BufMut, Bytes, BytesMut, Stream}; @@ -101,7 +99,13 @@ where Ok((head, Payload::None)) } _ => { - let pl: PayloadStream = Box::pin(PlStream::new(io, codec, created, pool)); + let pl: PayloadStream = Box::pin(PlStream::new( + io, + codec, + created, + pool, + head.version == Version::HTTP_10, + )); Ok((head, pl.into())) } } @@ -137,6 +141,7 @@ pub(super) struct PlStream { io: Option, codec: h1::ClientPayloadCodec, created: Instant, + http_10: bool, pool: Option, } @@ -146,12 +151,14 @@ impl PlStream { codec: h1::ClientCodec, created: Instant, pool: Option, + http_10: bool, ) -> Self { PlStream { io: Some(io), codec: codec.into_payload_codec(), created, pool, + http_10, } } } @@ -161,41 +168,46 @@ impl Stream for PlStream { fn poll_next( mut self: Pin<&mut Self>, - cx: &mut Context<'_>, + cx: &mut task::Context<'_>, ) -> Poll> { let mut this = self.as_mut(); loop { - return Poll::Ready(Some( - match ready!(this.io.as_ref().unwrap().poll_recv(&this.codec, cx)) { - Ok(chunk) => { - if let Some(chunk) = chunk { - Ok(chunk) - } else { - release_connection( - this.io.take().unwrap(), - !this.codec.keepalive(), - this.created, - this.pool.take(), - ); - return Poll::Ready(None); - } + let item = ready!(this.io.as_ref().unwrap().poll_recv(&this.codec, cx)); + return Poll::Ready(Some(match item { + Ok(chunk) => { + if let Some(chunk) = chunk { + Ok(chunk) + } else { + release_connection( + this.io.take().unwrap(), + !this.codec.keepalive(), + this.created, + this.pool.take(), + ); + return Poll::Ready(None); } - Err(RecvError::KeepAlive) => { - Err(io::Error::new(io::ErrorKind::TimedOut, "Keep-alive").into()) + } + Err(RecvError::KeepAlive) => { + Err(io::Error::new(io::ErrorKind::TimedOut, "Keep-alive").into()) + } + Err(RecvError::Stop) => { + Err(io::Error::new(io::ErrorKind::Other, "Dispatcher stopped").into()) + } + Err(RecvError::WriteBackpressure) => { + ready!(this.io.as_ref().unwrap().poll_flush(cx, false))?; + continue; + } + Err(RecvError::Decoder(err)) => Err(err), + Err(RecvError::PeerGone(Some(err))) => { + Err(PayloadError::Incomplete(Some(err))) + } + Err(RecvError::PeerGone(None)) => { + if this.http_10 { + return Poll::Ready(None); } - Err(RecvError::Stop) => { - Err(io::Error::new(io::ErrorKind::Other, "Dispatcher stopped") - .into()) - } - Err(RecvError::WriteBackpressure) => { - ready!(this.io.as_ref().unwrap().poll_flush(cx, false))?; - continue; - } - Err(RecvError::Decoder(err)) => Err(err), - Err(RecvError::PeerGone(Some(err))) => Err(err.into()), - Err(RecvError::PeerGone(None)) => return Poll::Ready(None), - }, - )); + Err(PayloadError::Incomplete(None)) + } + })); } } } diff --git a/ntex/src/http/client/h2proto.rs b/ntex/src/http/client/h2proto.rs index e04d4763..e98209f7 100644 --- a/ntex/src/http/client/h2proto.rs +++ b/ntex/src/http/client/h2proto.rs @@ -187,14 +187,17 @@ async fn get_response( err ); pl.set_error( - io::Error::new(io::ErrorKind::Other, err) - .into(), + io::Error::new( + io::ErrorKind::UnexpectedEof, + err, + ) + .into(), ); } _ => { pl.set_error( io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::Unsupported, "unexpected h2 message", ) .into(), @@ -216,7 +219,7 @@ async fn get_response( } } _ => Err(SendRequestError::Error(Box::new(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::Unsupported, "unexpected h2 message", )))), } diff --git a/ntex/src/http/client/pool.rs b/ntex/src/http/client/pool.rs index 7b0210c6..a56d4898 100644 --- a/ntex/src/http/client/pool.rs +++ b/ntex/src/http/client/pool.rs @@ -123,8 +123,8 @@ where } #[inline] - async fn not_ready(&self) { - self.connector.not_ready().await + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + self.connector.poll(cx) } async fn shutdown(&self) { diff --git a/ntex/src/http/client/response.rs b/ntex/src/http/client/response.rs index c68b6e73..9a450687 100644 --- a/ntex/src/http/client/response.rs +++ b/ntex/src/http/client/response.rs @@ -387,8 +387,8 @@ impl Future for ReadBody { let this = self.get_mut(); loop { - return match Pin::new(&mut this.stream).poll_next(cx)? { - Poll::Ready(Some(chunk)) => { + return match Pin::new(&mut this.stream).poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { if this.limit > 0 && (this.buf.len() + chunk.len()) > this.limit { Poll::Ready(Err(PayloadError::Overflow)) } else { @@ -397,6 +397,7 @@ impl Future for ReadBody { } } Poll::Ready(None) => Poll::Ready(Ok(this.buf.split().freeze())), + Poll::Ready(Some(Err(err))) => Poll::Ready(Err(err)), Poll::Pending => { if this.timeout.poll_elapsed(cx).is_ready() { Poll::Ready(Err(PayloadError::Incomplete(Some( diff --git a/ntex/src/http/encoding/decoder.rs b/ntex/src/http/encoding/decoder.rs index 5a518738..45020a3a 100644 --- a/ntex/src/http/encoding/decoder.rs +++ b/ntex/src/http/encoding/decoder.rs @@ -1,7 +1,5 @@ use std::{future::Future, io, io::Write, pin::Pin, task::Context, task::Poll}; -#[cfg(feature = "brotli")] -use brotli2::write::BrotliDecoder; use flate2::write::{GzDecoder, ZlibDecoder}; use super::Writer; @@ -27,10 +25,6 @@ where #[inline] pub fn new(stream: S, encoding: ContentEncoding) -> Decoder { let decoder = match encoding { - #[cfg(feature = "brotli")] - ContentEncoding::Br => Some(ContentDecoder::Br(Box::new(BrotliDecoder::new( - Writer::new(), - )))), ContentEncoding::Deflate => Some(ContentDecoder::Deflate(Box::new( ZlibDecoder::new(Writer::new()), ))), @@ -137,25 +131,11 @@ where enum ContentDecoder { Deflate(Box>), Gzip(Box>), - #[cfg(feature = "brotli")] - Br(Box>), } impl ContentDecoder { fn feed_eof(&mut self) -> io::Result> { match self { - #[cfg(feature = "brotli")] - ContentDecoder::Br(ref mut decoder) => match decoder.flush() { - Ok(()) => { - let b = decoder.get_mut().take(); - if !b.is_empty() { - Ok(Some(b)) - } else { - Ok(None) - } - } - Err(e) => Err(e), - }, ContentDecoder::Gzip(ref mut decoder) => match decoder.try_finish() { Ok(_) => { let b = decoder.get_mut().take(); @@ -183,19 +163,6 @@ impl ContentDecoder { fn feed_data(&mut self, data: Bytes) -> io::Result> { match self { - #[cfg(feature = "brotli")] - ContentDecoder::Br(ref mut decoder) => match decoder.write_all(&data) { - Ok(_) => { - decoder.flush()?; - let b = decoder.get_mut().take(); - if !b.is_empty() { - Ok(Some(b)) - } else { - Ok(None) - } - } - Err(e) => Err(e), - }, ContentDecoder::Gzip(ref mut decoder) => match decoder.write_all(&data) { Ok(_) => { decoder.flush()?; diff --git a/ntex/src/http/encoding/encoder.rs b/ntex/src/http/encoding/encoder.rs index 92003e60..086fc815 100644 --- a/ntex/src/http/encoding/encoder.rs +++ b/ntex/src/http/encoding/encoder.rs @@ -1,8 +1,6 @@ //! Stream encoder use std::{fmt, future::Future, io, io::Write, pin::Pin, task::Context, task::Poll}; -#[cfg(feature = "brotli")] -use brotli2::write::BrotliEncoder; use flate2::write::{GzEncoder, ZlibEncoder}; use crate::http::body::{Body, BodySize, MessageBody, ResponseBody}; @@ -117,7 +115,7 @@ impl MessageBody for Encoder { Poll::Ready(Ok(Err(e))) => return Poll::Ready(Some(Err(Box::new(e)))), Poll::Ready(Err(_)) => { return Poll::Ready(Some(Err(Box::new(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::Interrupted, "Canceled", ))))); } @@ -191,23 +189,11 @@ fn update_head(encoding: ContentEncoding, head: &mut ResponseHead) { enum ContentEncoder { Deflate(ZlibEncoder), Gzip(GzEncoder), - #[cfg(feature = "brotli")] - Br(BrotliEncoder), } impl ContentEncoder { fn can_encode(encoding: ContentEncoding) -> bool { - #[cfg(feature = "brotli")] - { - matches!( - encoding, - ContentEncoding::Deflate | ContentEncoding::Gzip | ContentEncoding::Br - ) - } - #[cfg(not(feature = "brotli"))] - { - matches!(encoding, ContentEncoding::Deflate | ContentEncoding::Gzip) - } + matches!(encoding, ContentEncoding::Deflate | ContentEncoding::Gzip) } fn encoder(encoding: ContentEncoding) -> Option { @@ -220,18 +206,12 @@ impl ContentEncoder { Writer::new(), flate2::Compression::fast(), ))), - #[cfg(feature = "brotli")] - ContentEncoding::Br => { - Some(ContentEncoder::Br(BrotliEncoder::new(Writer::new(), 3))) - } _ => None, } } fn take(&mut self) -> Bytes { match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(), ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(), ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(), } @@ -239,11 +219,6 @@ impl ContentEncoder { fn finish(self) -> Result { match self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(encoder) => match encoder.finish() { - Ok(writer) => Ok(writer.buf.freeze()), - Err(err) => Err(err), - }, ContentEncoder::Gzip(encoder) => match encoder.finish() { Ok(writer) => Ok(writer.buf.freeze()), Err(err) => Err(err), @@ -257,14 +232,6 @@ impl ContentEncoder { fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref mut encoder) => match encoder.write_all(data) { - Ok(_) => Ok(()), - Err(err) => { - log::trace!("Error decoding br encoding: {}", err); - Err(err) - } - }, ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) { Ok(_) => Ok(()), Err(err) => { @@ -288,8 +255,6 @@ impl fmt::Debug for ContentEncoder { match self { ContentEncoder::Deflate(_) => write!(f, "ContentEncoder::Deflate"), ContentEncoder::Gzip(_) => write!(f, "ContentEncoder::Gzip"), - #[cfg(feature = "brotli")] - ContentEncoder::Br(_) => write!(f, "ContentEncoder::Br"), } } } diff --git a/ntex/src/http/error.rs b/ntex/src/http/error.rs index 85642d6c..8703a258 100644 --- a/ntex/src/http/error.rs +++ b/ntex/src/http/error.rs @@ -29,7 +29,7 @@ pub trait ResponseError: fmt::Display + fmt::Debug { } } -impl<'a, T: ResponseError> ResponseError for &'a T { +impl ResponseError for &T { fn error_response(&self) -> Response { (*self).error_response() } @@ -217,7 +217,7 @@ pub enum BlockingError { impl From for PayloadError { fn from(_: crate::rt::JoinError) -> Self { PayloadError::Io(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::Interrupted, "Operation is canceled", )) } @@ -228,7 +228,7 @@ impl From> for PayloadError { match err { BlockingError::Error(e) => PayloadError::Io(e), BlockingError::Canceled => PayloadError::Io(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::Interrupted, "Operation is canceled", )), } diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 244853b8..7a2142ea 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,5 +1,5 @@ //! HTTP/1 protocol dispatcher -use std::{error, future, io, marker, pin::Pin, rc::Rc, task::Context, task::Poll}; +use std::{error, future, io, marker, mem, pin::Pin, rc::Rc, task::Context, task::Poll}; use crate::io::{Decoded, Filter, Io, IoStatusUpdate, RecvError}; use crate::service::{PipelineCall, Service}; @@ -144,7 +144,20 @@ where inner.send_response(res, body) } Poll::Ready(Err(err)) => inner.control(Control::err(err)), - Poll::Pending => ready!(inner.poll_request(cx)), + Poll::Pending => { + // state changed because of error. + // spawn current publish future to runtime + // so it could complete error handling + let st = ready!(inner.poll_request(cx)); + if inner.payload.is_some() { + if let State::CallPublish { fut } = + mem::replace(&mut *this.st, State::ReadRequest) + { + crate::rt::spawn(fut); + } + } + st + } }, // handle control service responses State::CallControl { fut } => match Pin::new(fut).poll(cx) { @@ -181,7 +194,13 @@ where Poll::Pending => ready!(inner.poll_request(cx)), }, // read request and call service - State::ReadRequest => ready!(inner.poll_read_request(cx)), + State::ReadRequest => { + if inner.flags.contains(Flags::SENDPAYLOAD_AND_STOP) { + inner.stop() + } else { + ready!(inner.poll_read_request(cx)) + } + } // consume request's payload State::ReadPayload => { let result = inner.poll_request_payload(cx); @@ -333,7 +352,7 @@ where .io .encode(Message::Item((msg, body.size())), &self.codec) .map_err(|err| { - if let Some(mut payload) = self.payload.take() { + if let Some(ref mut payload) = self.payload { payload.1.set_error(PayloadError::Incomplete(None)); } err @@ -432,7 +451,7 @@ where } fn set_payload_error(&mut self, err: PayloadError) { - if let Some(mut payload) = self.payload.take() { + if let Some(ref mut payload) = self.payload { payload.1.set_error(err); } } @@ -1263,4 +1282,21 @@ mod tests { assert!(mark.load(Ordering::Relaxed) == 1536); assert!(err_mark.load(Ordering::Relaxed) == 1); } + + #[crate::rt_test] + async fn test_unconsumed_payload() { + let (client, server) = Io::create(); + client.remote_buffer_cap(4096); + client.write("GET /test HTTP/1.1\r\ncontent-length:512\r\n\r\n"); + + let mut h1 = h1(server, |_| { + Box::pin(async { Ok::<_, io::Error>(Response::Ok().body("TEST")) }) + }); + // required because io shutdown is async oper + assert!(poll_fn(|cx| Pin::new(&mut h1).poll(cx)).await.is_ok()); + + assert!(h1.inner.io.is_closed()); + let buf = client.local_buffer(|buf| buf.split()); + assert_eq!(&buf[..15], b"HTTP/1.1 200 OK"); + } } diff --git a/ntex/src/http/h1/payload.rs b/ntex/src/http/h1/payload.rs index 1fe5e5a5..ac3c8609 100644 --- a/ntex/src/http/h1/payload.rs +++ b/ntex/src/http/h1/payload.rs @@ -3,8 +3,7 @@ use std::rc::{Rc, Weak}; use std::task::{Context, Poll}; use std::{cell::RefCell, collections::VecDeque, pin::Pin}; -use crate::http::error::PayloadError; -use crate::{task::LocalWaker, util::Bytes, util::Stream}; +use crate::{http::error::PayloadError, task::LocalWaker, util::Bytes, util::Stream}; /// max buffer size 32k const MAX_BUFFER_SIZE: usize = 32_768; @@ -119,7 +118,7 @@ impl PayloadSender { // we check only if Payload (other side) is alive, // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { - if shared.borrow().need_read { + if shared.borrow().flags.contains(Flags::NEED_READ) { PayloadStatus::Read } else { shared.borrow_mut().io_task.register(cx.waker()); @@ -131,12 +130,20 @@ impl PayloadSender { } } +bitflags::bitflags! { + #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + struct Flags: u8 { + const EOF = 0b0000_0001; + const ERROR = 0b0000_0010; + const NEED_READ = 0b0000_0100; + } +} + #[derive(Debug)] struct Inner { len: usize, - eof: bool, + flags: Flags, err: Option, - need_read: bool, items: VecDeque, task: LocalWaker, io_task: LocalWaker, @@ -144,12 +151,16 @@ struct Inner { impl Inner { fn new(eof: bool) -> Self { + let flags = if eof { + Flags::EOF | Flags::NEED_READ + } else { + Flags::NEED_READ + }; Inner { - eof, + flags, len: 0, err: None, items: VecDeque::new(), - need_read: true, task: LocalWaker::new(), io_task: LocalWaker::new(), } @@ -157,18 +168,23 @@ impl Inner { fn set_error(&mut self, err: PayloadError) { self.err = Some(err); + self.flags.insert(Flags::ERROR); self.task.wake() } fn feed_eof(&mut self) { - self.eof = true; + self.flags.insert(Flags::EOF); self.task.wake() } fn feed_data(&mut self, data: Bytes) { self.len += data.len(); self.items.push_back(data); - self.need_read = self.len < MAX_BUFFER_SIZE; + if self.len < MAX_BUFFER_SIZE { + self.flags.insert(Flags::NEED_READ); + } else { + self.flags.remove(Flags::NEED_READ); + } self.task.wake(); } @@ -178,19 +194,25 @@ impl Inner { ) -> Poll>> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); - self.need_read = self.len < MAX_BUFFER_SIZE; + if self.len < MAX_BUFFER_SIZE { + self.flags.insert(Flags::NEED_READ); + } else { + self.flags.remove(Flags::NEED_READ); + } - if self.need_read && !self.eof { + if self.flags.contains(Flags::NEED_READ) + && !self.flags.intersects(Flags::EOF | Flags::ERROR) + { self.task.register(cx.waker()); } self.io_task.wake(); Poll::Ready(Some(Ok(data))) } else if let Some(err) = self.err.take() { Poll::Ready(Some(Err(err))) - } else if self.eof { + } else if self.flags.intersects(Flags::EOF | Flags::ERROR) { Poll::Ready(None) } else { - self.need_read = true; + self.flags.insert(Flags::NEED_READ); self.task.register(cx.waker()); self.io_task.wake(); Poll::Pending diff --git a/ntex/src/http/h1/service.rs b/ntex/src/http/h1/service.rs index e566b9ca..62c4b70f 100644 --- a/ntex/src/http/h1/service.rs +++ b/ntex/src/http/h1/service.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, cell::RefCell, error::Error, fmt, marker, rc::Rc}; +use std::{cell::Cell, cell::RefCell, error::Error, fmt, marker, rc::Rc, task::Context}; use crate::http::body::MessageBody; use crate::http::config::{DispatcherConfig, ServiceConfig}; @@ -6,7 +6,7 @@ use crate::http::error::{DispatchError, ResponseError}; use crate::http::{request::Request, response::Response}; use crate::io::{types, Filter, Io, IoRef}; use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; -use crate::{channel::oneshot, util::join, util::select, util::HashSet}; +use crate::{channel::oneshot, util::join, util::HashSet}; use super::control::{Control, ControlAck}; use super::default::DefaultControlService; @@ -230,10 +230,14 @@ where }) } - #[inline] - async fn not_ready(&self) { + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { let cfg = self.config.as_ref(); - select(cfg.control.not_ready(), cfg.service.not_ready()).await; + cfg.control + .poll(cx) + .map_err(|e| DispatchError::Control(Box::new(e)))?; + cfg.service + .poll(cx) + .map_err(|e| DispatchError::Service(Box::new(e))) } async fn shutdown(&self) { @@ -286,7 +290,7 @@ where let mut inflight = self.inflight.borrow_mut(); inflight.remove(&ioref); - if inflight.len() == 0 { + if inflight.is_empty() { if let Some(tx) = self.tx.take() { let _ = tx.send(()); } diff --git a/ntex/src/http/h2/service.rs b/ntex/src/http/h2/service.rs index 00889942..06fb0986 100644 --- a/ntex/src/http/h2/service.rs +++ b/ntex/src/http/h2/service.rs @@ -1,5 +1,5 @@ use std::cell::{Cell, RefCell}; -use std::{error::Error, fmt, future::poll_fn, io, marker, mem, rc::Rc}; +use std::{error::Error, fmt, future::poll_fn, io, marker, mem, rc::Rc, task::Context}; use ntex_h2::{self as h2, frame::StreamId, server}; @@ -227,8 +227,11 @@ where } #[inline] - async fn not_ready(&self) { - self.config.service.not_ready().await; + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + self.config + .service + .poll(cx) + .map_err(|e| DispatchError::Service(Box::new(e))) } #[inline] @@ -283,7 +286,7 @@ where let mut inflight = self.inflight.borrow_mut(); inflight.remove(&ioref); - if inflight.len() == 0 { + if inflight.is_empty() { if let Some(tx) = self.tx.take() { let _ = tx.send(()); } @@ -405,7 +408,9 @@ where h2::MessageKind::Disconnect(err) => { log::debug!("Connection is disconnected {:?}", err); if let Some(mut sender) = self.streams.borrow_mut().remove(&stream.id()) { - sender.set_error(io::Error::new(io::ErrorKind::Other, err).into()); + sender.set_error( + io::Error::new(io::ErrorKind::UnexpectedEof, err).into(), + ); } return Ok(()); } diff --git a/ntex/src/http/helpers.rs b/ntex/src/http/helpers.rs index 588aafce..ed3a26b5 100644 --- a/ntex/src/http/helpers.rs +++ b/ntex/src/http/helpers.rs @@ -6,7 +6,7 @@ use crate::util::BytesMut; pub(crate) struct Writer<'a>(pub(crate) &'a mut BytesMut); -impl<'a> io::Write for Writer<'a> { +impl io::Write for Writer<'_> { fn write(&mut self, buf: &[u8]) -> io::Result { self.0.extend_from_slice(buf); Ok(buf.len()) diff --git a/ntex/src/http/mod.rs b/ntex/src/http/mod.rs index 29d79bb3..3190259f 100644 --- a/ntex/src/http/mod.rs +++ b/ntex/src/http/mod.rs @@ -1,5 +1,4 @@ //! Http protocol support. -pub mod body; mod builder; pub mod client; mod config; @@ -36,4 +35,4 @@ pub use crate::io::types::HttpProtocol; // re-exports pub use ntex_http::uri::{self, Uri}; -pub use ntex_http::{HeaderMap, Method, StatusCode, Version}; +pub use ntex_http::{body, HeaderMap, Method, StatusCode, Version}; diff --git a/ntex/src/http/response.rs b/ntex/src/http/response.rs index faea582c..9c68c6fe 100644 --- a/ntex/src/http/response.rs +++ b/ntex/src/http/response.rs @@ -227,6 +227,20 @@ impl Response { } } +#[cfg(test)] +impl Response { + pub(crate) fn get_body_ref(&self) -> &[u8] { + let b = match *self.body() { + ResponseBody::Body(ref b) => b, + ResponseBody::Other(ref b) => b, + }; + match b { + Body::Bytes(bin) => bin, + _ => panic!(), + } + } +} + impl fmt::Debug for Response { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let res = writeln!( @@ -925,7 +939,7 @@ mod tests { let resp = Response::build(StatusCode::OK).json(&vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("application/json")); - assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); + assert_eq!(resp.get_body_ref(), b"[\"v1\",\"v2\",\"v3\"]"); } #[test] @@ -935,14 +949,7 @@ mod tests { .json(&vec!["v1", "v2", "v3"]); let ct = resp.headers().get(CONTENT_TYPE).unwrap(); assert_eq!(ct, HeaderValue::from_static("text/json")); - assert_eq!(resp.body().get_ref(), b"[\"v1\",\"v2\",\"v3\"]"); - } - - #[test] - fn test_serde_json_in_body() { - use serde_json::json; - let resp = Response::build(StatusCode::OK).body(json!({"test-key":"test-value"})); - assert_eq!(resp.body().get_ref(), br#"{"test-key":"test-value"}"#); + assert_eq!(resp.get_body_ref(), b"[\"v1\",\"v2\",\"v3\"]"); } #[test] @@ -955,7 +962,7 @@ mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); let resp: Response = b"test".as_ref().into(); assert_eq!(resp.status(), StatusCode::OK); @@ -964,7 +971,7 @@ mod tests { HeaderValue::from_static("application/octet-stream") ); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); let resp: Response = "test".to_owned().into(); assert_eq!(resp.status(), StatusCode::OK); @@ -973,7 +980,7 @@ mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); let resp: Response = (&"test".to_owned()).into(); assert_eq!(resp.status(), StatusCode::OK); @@ -982,7 +989,7 @@ mod tests { HeaderValue::from_static("text/plain; charset=utf-8") ); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); let b = Bytes::from_static(b"test"); let resp: Response = b.into(); @@ -992,7 +999,7 @@ mod tests { HeaderValue::from_static("application/octet-stream") ); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); let b = Bytes::from_static(b"test"); let resp: Response = b.into(); @@ -1002,7 +1009,7 @@ mod tests { HeaderValue::from_static("application/octet-stream") ); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); let b = BytesMut::from("test"); let resp: Response = b.into(); @@ -1013,7 +1020,7 @@ mod tests { ); assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); let builder = Response::build_from(ResponseBuilder::new(StatusCode::OK)) .keep_alive() diff --git a/ntex/src/http/service.rs b/ntex/src/http/service.rs index 42156e63..6174146c 100644 --- a/ntex/src/http/service.rs +++ b/ntex/src/http/service.rs @@ -1,8 +1,8 @@ -use std::{cell::Cell, cell::RefCell, error, fmt, marker, rc::Rc}; +use std::{cell::Cell, cell::RefCell, error, fmt, marker, rc::Rc, task::Context}; use crate::io::{types, Filter, Io, IoRef}; use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; -use crate::{channel::oneshot, util::join, util::select, util::HashSet}; +use crate::{channel::oneshot, util::join, util::HashSet}; use super::body::MessageBody; use super::builder::HttpServiceBuilder; @@ -312,12 +312,16 @@ where } #[inline] - async fn not_ready(&self) { + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { let cfg = self.config.as_ref(); - select(cfg.control.not_ready(), cfg.service.not_ready()).await; + cfg.control + .poll(cx) + .map_err(|e| DispatchError::Control(Box::new(e)))?; + cfg.service + .poll(cx) + .map_err(|e| DispatchError::Service(Box::new(e))) } - #[inline] async fn shutdown(&self) { self.config.shutdown(); @@ -383,7 +387,7 @@ where let mut inflight = self.inflight.borrow_mut(); inflight.remove(&ioref); - if inflight.len() == 0 { + if inflight.is_empty() { if let Some(tx) = self.tx.take() { let _ = tx.send(()); } diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index 1a395eae..0e4a6559 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -8,10 +8,10 @@ use coo_kie::{Cookie, CookieJar}; use crate::io::Filter; use crate::io::Io; use crate::server::Server; +use crate::service::ServiceFactory; #[cfg(feature = "ws")] use crate::ws::{error::WsClientError, WsClient, WsConnection}; -use crate::{rt::System, service::ServiceFactory}; -use crate::{time::Millis, time::Seconds, util::Bytes}; +use crate::{rt::System, time::sleep, time::Millis, time::Seconds, util::Bytes}; use super::client::{Client, ClientRequest, ClientResponse, Connector}; use super::error::{HttpError, PayloadError}; @@ -244,10 +244,15 @@ where .workers(1) .disable_signals() .run(); - tx.send((system, srv, local_addr)).unwrap(); + + crate::rt::spawn(async move { + sleep(Millis(125)).await; + tx.send((system, srv, local_addr)).unwrap(); + }); Ok(()) }) }); + thread::sleep(std::time::Duration::from_millis(150)); let (system, server, addr) = rx.recv().unwrap(); @@ -257,7 +262,7 @@ where server, client: Client::build().finish(), } - .set_client_timeout(Seconds(30), Millis(30_000)) + .set_client_timeout(Seconds(90), Millis(90_000)) } #[derive(Debug)] diff --git a/ntex/src/lib.rs b/ntex/src/lib.rs index 6a6a02cb..a15ee31a 100644 --- a/ntex/src/lib.rs +++ b/ntex/src/lib.rs @@ -28,7 +28,7 @@ pub use ntex_macros::{rt_main as main, rt_test as test}; #[cfg(test)] pub(crate) use ntex_macros::rt_test2 as rt_test; -pub use ntex_service::{forward_ready, forward_shutdown}; +pub use ntex_service::{forward_poll, forward_ready, forward_shutdown}; pub mod http; pub mod web; @@ -123,4 +123,15 @@ pub mod util { #[doc(hidden)] #[deprecated] pub use std::task::ready; + + #[doc(hidden)] + pub fn enable_test_logging() { + #[cfg(not(feature = "no-test-logging"))] + if std::env::var("NTEX_NO_TEST_LOG").is_err() { + if std::env::var("RUST_LOG").is_err() { + std::env::set_var("RUST_LOG", "trace"); + } + let _ = env_logger::builder().is_test(true).try_init(); + } + } } diff --git a/ntex/src/web/app_service.rs b/ntex/src/web/app_service.rs index 7752c1c2..b4b0ff9d 100644 --- a/ntex/src/web/app_service.rs +++ b/ntex/src/web/app_service.rs @@ -1,11 +1,11 @@ -use std::{cell::RefCell, marker, rc::Rc}; +use std::{cell::RefCell, marker, rc::Rc, task::Context}; use crate::http::{Request, Response}; use crate::router::{Path, ResourceDef, Router}; use crate::service::boxed::{self, BoxService, BoxServiceFactory}; use crate::service::dev::ServiceChainFactory; use crate::service::{fn_service, Middleware, Service, ServiceCtx, ServiceFactory}; -use crate::util::{join, select, BoxFuture, Extensions}; +use crate::util::{join, BoxFuture, Extensions}; use super::config::AppConfig; use super::error::ErrorRenderer; @@ -202,6 +202,7 @@ where type Response = WebResponse; type Error = T::Error; + crate::forward_poll!(service); crate::forward_ready!(service); crate::forward_shutdown!(service); @@ -302,8 +303,9 @@ where } #[inline] - async fn not_ready(&self) { - select(self.filter.not_ready(), self.routing.not_ready()).await; + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + self.filter.poll(cx)?; + self.routing.poll(cx) } async fn call( diff --git a/ntex/src/web/config.rs b/ntex/src/web/config.rs index 91c5034c..c7edceb6 100644 --- a/ntex/src/web/config.rs +++ b/ntex/src/web/config.rs @@ -68,7 +68,7 @@ pub struct ServiceConfig { } impl ServiceConfig { - pub(crate) fn new() -> Self { + pub fn new() -> Self { Self { services: Vec::new(), state: Extensions::new(), @@ -132,7 +132,7 @@ mod tests { use crate::http::{Method, StatusCode}; use crate::util::Bytes; use crate::web::test::{call_service, init_service, read_body, TestRequest}; - use crate::web::{self, App, HttpRequest, HttpResponse}; + use crate::web::{self, App, DefaultError, HttpRequest, HttpResponse}; #[crate::rt_test] async fn test_configure_state() { @@ -205,4 +205,11 @@ mod tests { let resp = call_service(&srv, req).await; assert_eq!(resp.status(), StatusCode::OK); } + + #[test] + fn test_new_service_config() { + let cfg: ServiceConfig = ServiceConfig::new(); + assert!(cfg.services.is_empty()); + assert!(cfg.external.is_empty()); + } } diff --git a/ntex/src/web/middleware/compress.rs b/ntex/src/web/middleware/compress.rs index 3d8c1db7..dc848ab1 100644 --- a/ntex/src/web/middleware/compress.rs +++ b/ntex/src/web/middleware/compress.rs @@ -67,6 +67,7 @@ where type Response = WebResponse; type Error = S::Error; + crate::forward_poll!(service); crate::forward_ready!(service); crate::forward_shutdown!(service); diff --git a/ntex/src/web/middleware/defaultheaders.rs b/ntex/src/web/middleware/defaultheaders.rs index 5aa0461e..670361be 100644 --- a/ntex/src/web/middleware/defaultheaders.rs +++ b/ntex/src/web/middleware/defaultheaders.rs @@ -110,6 +110,7 @@ where type Response = WebResponse; type Error = S::Error; + crate::forward_poll!(service); crate::forward_ready!(service); crate::forward_shutdown!(service); diff --git a/ntex/src/web/middleware/logger.rs b/ntex/src/web/middleware/logger.rs index cc5bbfeb..26f4f4a6 100644 --- a/ntex/src/web/middleware/logger.rs +++ b/ntex/src/web/middleware/logger.rs @@ -139,6 +139,7 @@ where type Response = WebResponse; type Error = S::Error; + crate::forward_poll!(service); crate::forward_ready!(service); crate::forward_shutdown!(service); @@ -399,7 +400,7 @@ pub(crate) struct FormatDisplay<'a>( &'a dyn Fn(&mut fmt::Formatter<'_>) -> Result<(), fmt::Error>, ); -impl<'a> fmt::Display for FormatDisplay<'a> { +impl fmt::Display for FormatDisplay<'_> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { (self.0)(fmt) } diff --git a/ntex/src/web/mod.rs b/ntex/src/web/mod.rs index cf1686e0..8d9adf4d 100644 --- a/ntex/src/web/mod.rs +++ b/ntex/src/web/mod.rs @@ -82,7 +82,7 @@ mod route; mod scope; mod server; mod service; -mod stack; +pub mod stack; pub mod test; pub mod types; mod util; @@ -128,6 +128,7 @@ pub mod dev { //! The purpose of this module is to alleviate imports of many common //! traits by adding a glob import to the top of ntex::web heavy modules: + pub use crate::web::app_service::AppService; pub use crate::web::config::AppConfig; pub use crate::web::info::ConnectionInfo; pub use crate::web::rmap::ResourceMap; diff --git a/ntex/src/web/responder.rs b/ntex/src/web/responder.rs index b42ee506..55e3a12f 100644 --- a/ntex/src/web/responder.rs +++ b/ntex/src/web/responder.rs @@ -143,7 +143,7 @@ impl Responder for String { } } -impl<'a, Err: ErrorRenderer> Responder for &'a String { +impl Responder for &String { async fn respond_to(self, _: &HttpRequest) -> Response { Response::build(StatusCode::OK) .content_type("text/plain; charset=utf-8") @@ -371,7 +371,7 @@ pub(crate) mod tests { let resp: HttpResponse = responder("test").respond_to(&req).await; assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") @@ -379,7 +379,7 @@ pub(crate) mod tests { let resp: HttpResponse = responder(&b"test"[..]).respond_to(&req).await; assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") @@ -387,7 +387,7 @@ pub(crate) mod tests { let resp: HttpResponse = responder("test".to_string()).respond_to(&req).await; assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") @@ -395,7 +395,7 @@ pub(crate) mod tests { let resp: HttpResponse = responder(&"test".to_string()).respond_to(&req).await; assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") @@ -405,7 +405,7 @@ pub(crate) mod tests { .respond_to(&req) .await; assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") @@ -415,7 +415,7 @@ pub(crate) mod tests { .respond_to(&req) .await; assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("application/octet-stream") @@ -440,7 +440,7 @@ pub(crate) mod tests { ) .await; assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.body().get_ref(), b"test"); + assert_eq!(resp.get_body_ref(), b"test"); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("text/plain; charset=utf-8") @@ -463,7 +463,7 @@ pub(crate) mod tests { .respond_to(&req) .await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().get_ref(), b"test"); + assert_eq!(res.get_body_ref(), b"test"); let res = responder("test".to_string()) .with_header("content-type", "json") @@ -471,7 +471,7 @@ pub(crate) mod tests { .await; assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().get_ref(), b"test"); + assert_eq!(res.get_body_ref(), b"test"); assert_eq!( res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("json") @@ -487,7 +487,7 @@ pub(crate) mod tests { ) .await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); - assert_eq!(res.body().get_ref(), b"test"); + assert_eq!(res.get_body_ref(), b"test"); let req = TestRequest::default().to_http_request(); let res = @@ -496,7 +496,7 @@ pub(crate) mod tests { .respond_to(&req) .await; assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.body().get_ref(), b"test"); + assert_eq!(res.get_body_ref(), b"test"); assert_eq!( res.headers().get(CONTENT_TYPE).unwrap(), HeaderValue::from_static("json") diff --git a/ntex/src/web/scope.rs b/ntex/src/web/scope.rs index 6e462cec..a3bc0458 100644 --- a/ntex/src/web/scope.rs +++ b/ntex/src/web/scope.rs @@ -1,11 +1,11 @@ -use std::{cell::RefCell, fmt, rc::Rc}; +use std::{cell::RefCell, fmt, rc::Rc, task::Context}; use crate::http::Response; use crate::router::{IntoPattern, ResourceDef, Router}; use crate::service::boxed::{self, BoxService, BoxServiceFactory}; use crate::service::{chain_factory, dev::ServiceChainFactory, IntoServiceFactory}; use crate::service::{Identity, Middleware, Service, ServiceCtx, ServiceFactory}; -use crate::util::{join, select, Extensions}; +use crate::util::{join, Extensions}; use super::app::Filter; use super::config::ServiceConfig; @@ -495,10 +495,12 @@ where } #[inline] - async fn not_ready(&self) { - select(self.filter.not_ready(), self.routing.not_ready()).await; + fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> { + self.filter.poll(cx)?; + self.routing.poll(cx) } + #[inline] async fn call( &self, req: WebRequest, diff --git a/ntex/src/web/server.rs b/ntex/src/web/server.rs index 9efe14cb..05bb12b6 100644 --- a/ntex/src/web/server.rs +++ b/ntex/src/web/server.rs @@ -310,6 +310,14 @@ where self } + /// Enable cpu affinity + /// + /// By default affinity is disabled. + pub fn enable_affinity(mut self) -> Self { + self.builder = self.builder.enable_affinity(); + self + } + /// Set io tag for web server pub fn tag(self, tag: &'static str) -> Self { self.config.lock().unwrap().tag = tag; @@ -467,7 +475,7 @@ where Err(e) } else { Err(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::InvalidInput, "Cannot bind to address.", )) } diff --git a/ntex/src/web/stack.rs b/ntex/src/web/stack.rs index e708102b..d8c8643f 100644 --- a/ntex/src/web/stack.rs +++ b/ntex/src/web/stack.rs @@ -73,6 +73,7 @@ where ctx.call(&self.svc, req).await.map_err(Into::into) } + crate::forward_poll!(svc); crate::forward_ready!(svc); crate::forward_shutdown!(svc); } diff --git a/ntex/src/web/test.rs b/ntex/src/web/test.rs index 20c65fd0..1307ad9f 100644 --- a/ntex/src/web/test.rs +++ b/ntex/src/web/test.rs @@ -697,7 +697,10 @@ where .set_tag("test", "WEB-SRV") .run(); - tx.send((System::current(), srv, local_addr)).unwrap(); + crate::rt::spawn(async move { + sleep(Millis(125)).await; + tx.send((System::current(), srv, local_addr)).unwrap(); + }); Ok(()) }) }); @@ -717,8 +720,8 @@ where .map_err(|e| log::error!("Cannot set alpn protocol: {:?}", e)); Connector::default() .lifetime(Seconds::ZERO) - .keep_alive(Seconds(30)) - .timeout(Millis(30_000)) + .keep_alive(Seconds(60)) + .timeout(Millis(90_000)) .disconnect_timeout(Seconds(5)) .openssl(builder.build()) .finish() @@ -727,14 +730,14 @@ where { Connector::default() .lifetime(Seconds::ZERO) - .timeout(Millis(30_000)) + .timeout(Millis(90_000)) .finish() } }; Client::build() .connector(connector) - .timeout(Seconds(30)) + .timeout(Seconds(90)) .finish() }; @@ -929,7 +932,7 @@ impl TestServer { WsClient::build(self.url(path)) .address(self.addr) - .timeout(Seconds(30)) + .timeout(Seconds(60)) .openssl(builder.build()) .take() .finish() @@ -945,7 +948,7 @@ impl TestServer { } else { WsClient::build(self.url(path)) .address(self.addr) - .timeout(Seconds(30)) + .timeout(Seconds(60)) .finish() .unwrap() .connect() diff --git a/ntex/src/web/types/form.rs b/ntex/src/web/types/form.rs index 5605aa7f..1b69407c 100644 --- a/ntex/src/web/types/form.rs +++ b/ntex/src/web/types/form.rs @@ -493,6 +493,6 @@ mod tests { HeaderValue::from_static("application/x-www-form-urlencoded") ); - assert_eq!(resp.body().get_ref(), b"hello=world&counter=123"); + assert_eq!(resp.get_body_ref(), b"hello=world&counter=123"); } } diff --git a/ntex/src/web/types/json.rs b/ntex/src/web/types/json.rs index 580b1db8..cddbfd70 100644 --- a/ntex/src/web/types/json.rs +++ b/ntex/src/web/types/json.rs @@ -294,7 +294,7 @@ where let json = if let Ok(Some(mime)) = req.mime_type() { mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) - || ctype.as_ref().map_or(false, |predicate| predicate(mime)) + || ctype.as_ref().is_some_and(|predicate| predicate(mime)) } else { false }; @@ -434,7 +434,7 @@ mod tests { header::HeaderValue::from_static("application/json") ); - assert_eq!(resp.body().get_ref(), b"{\"name\":\"test\"}"); + assert_eq!(resp.get_body_ref(), b"{\"name\":\"test\"}"); } #[crate::rt_test] diff --git a/ntex/src/ws/transport.rs b/ntex/src/ws/transport.rs index f6e27f7a..78473fab 100644 --- a/ntex/src/ws/transport.rs +++ b/ntex/src/ws/transport.rs @@ -54,7 +54,7 @@ impl WsTransport { Ok(()) } else { self.insert_flags(Flags::PROTO_ERR); - Err(io::Error::new(io::ErrorKind::Other, err_message)) + Err(io::Error::new(io::ErrorKind::InvalidData, err_message)) } } } @@ -96,7 +96,7 @@ impl FilterLayer for WsTransport { self.codec.decode_vec(&mut src).map_err(|e| { log::trace!("Failed to decode ws codec frames: {:?}", e); self.insert_flags(Flags::PROTO_ERR); - io::Error::new(io::ErrorKind::Other, e) + io::Error::new(io::ErrorKind::InvalidData, e) })? { frame } else { @@ -123,14 +123,14 @@ impl FilterLayer for WsTransport { Frame::Continuation(Item::FirstText(_)) => { self.insert_flags(Flags::PROTO_ERR); return Err(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::InvalidData, "WebSocket Text continuation frames are not supported", )); } Frame::Text(_) => { self.insert_flags(Flags::PROTO_ERR); return Err(io::Error::new( - io::ErrorKind::Other, + io::ErrorKind::InvalidData, "WebSockets Text frames are not supported", )); } diff --git a/ntex/tests/http_awc_client.rs b/ntex/tests/http_awc_client.rs index eb18624b..bd4c7e0a 100644 --- a/ntex/tests/http_awc_client.rs +++ b/ntex/tests/http_awc_client.rs @@ -3,16 +3,14 @@ use std::io::{Read, Write}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use brotli2::write::BrotliEncoder; use coo_kie::Cookie; use flate2::{read::GzDecoder, write::GzEncoder, write::ZlibEncoder, Compression}; -use futures_util::stream::once; use rand::Rng; -use ntex::http::client::error::{JsonPayloadError, SendRequestError}; +use ntex::http::client::error::SendRequestError; use ntex::http::client::{Client, Connector}; use ntex::http::test::server as test_server; -use ntex::http::{header, HttpMessage, HttpService, Method}; +use ntex::http::{header, HttpMessage, HttpService}; use ntex::service::{chain_factory, map_config}; use ntex::web::dev::AppConfig; use ntex::web::middleware::Compress; @@ -220,7 +218,7 @@ async fn test_connection_reuse() { ))) }); - let client = Client::build().timeout(Seconds(10)).finish(); + let client = Client::build().timeout(Seconds(30)).finish(); // req 1 let request = client.get(srv.url("/")).send(); @@ -255,7 +253,7 @@ async fn test_connection_force_close() { ))) }); - let client = Client::build().timeout(Seconds(10)).finish(); + let client = Client::build().timeout(Seconds(30)).finish(); // req 1 let request = client.get(srv.url("/")).force_close().send(); @@ -263,7 +261,7 @@ async fn test_connection_force_close() { assert!(response.status().is_success()); // req 2 - let client = Client::build().timeout(Seconds(10)).finish(); + let client = Client::build().timeout(Seconds(30)).finish(); let req = client.post(srv.url("/")).force_close(); let response = req.send().await.unwrap(); assert!(response.status().is_success()); @@ -291,7 +289,7 @@ async fn test_connection_server_close() { ))) }); - let client = Client::build().timeout(Seconds(10)).finish(); + let client = Client::build().timeout(Seconds(30)).finish(); // req 1 let request = client.get(srv.url("/")).send(); @@ -510,19 +508,21 @@ async fn test_client_gzip_encoding_large() { async fn test_client_gzip_encoding_large_random() { let data = rand::thread_rng() .sample_iter(&rand::distributions::Alphanumeric) - .take(100_000) + .take(1_048_500) .map(char::from) .collect::(); let srv = test::server(|| { - App::new().service(web::resource("/").route(web::to(|data: Bytes| async move { - let mut e = GzEncoder::new(Vec::new(), Compression::default()); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "gzip") - .body(data) - }))) + App::new() + .state(web::types::PayloadConfig::default().limit(1_048_576)) + .service(web::resource("/").route(web::to(|data: Bytes| async move { + let mut e = GzEncoder::new(Vec::new(), Compression::default()); + e.write_all(&data).unwrap(); + let data = e.finish().unwrap(); + HttpResponse::Ok() + .header("content-encoding", "gzip") + .body(data) + }))) }); // client request @@ -530,130 +530,10 @@ async fn test_client_gzip_encoding_large_random() { assert!(response.status().is_success()); // read response - let bytes = response.body().await.unwrap(); + let bytes = response.body().limit(1_048_576).await.unwrap(); assert_eq!(bytes, Bytes::from(data)); } -#[ntex::test] -async fn test_client_brotli_encoding() { - let srv = test::server(|| { - App::new().service(web::resource("/").route(web::to(|data: Bytes| async move { - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "br") - .body(data) - }))) - }); - - // client request - let mut response = srv.post("/").send_body(STR).await.unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[ntex::test] -async fn test_client_brotli_encoding_large_random() { - let data = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(70_000) - .map(char::from) - .collect::(); - - let srv = test::server(|| { - App::new().service(web::resource("/").route(web::to(|data: Bytes| async move { - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(&data).unwrap(); - let data = e.finish().unwrap(); - HttpResponse::Ok() - .header("content-encoding", "br") - .body(data) - }))) - }); - - // client request - let mut response = srv.post("/").send_body(data.clone()).await.unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data.clone())); - - // frozen request - let request = srv.post("/").timeout(Seconds(30)).freeze().unwrap(); - assert_eq!(request.get_method(), Method::POST); - assert_eq!(request.get_uri(), srv.url("/").as_str()); - let mut response = request.send_body(data.clone()).await.unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data.clone())); - - // extra header - let mut response = request - .extra_header("x-test2", "222") - .send_body(data.clone()) - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data.clone())); - - // client stream request - let mut response = srv - .post("/") - .send_stream(once(Ready::Ok::<_, JsonPayloadError>(Bytes::from( - data.clone(), - )))) - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data.clone())); - - // frozen request - let request = srv.post("/").timeout(Seconds(30)).freeze().unwrap(); - let mut response = request - .send_stream(once(Ready::Ok::<_, JsonPayloadError>(Bytes::from( - data.clone(), - )))) - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data.clone())); - - let mut response = request - .extra_header("x-test2", "222") - .send_stream(once(Ready::Ok::<_, JsonPayloadError>(Bytes::from( - data.clone(), - )))) - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes.len(), data.len()); - assert_eq!(bytes, Bytes::from(data.clone())); -} - #[ntex::test] async fn test_client_deflate_encoding() { let srv = test::server(|| { @@ -814,7 +694,7 @@ async fn client_read_until_eof() { // client request let req = Client::build() - .timeout(Seconds(5)) + .timeout(Seconds(30)) .finish() .get(format!("http://{}/", addr).as_str()); let mut response = req.send().await.unwrap(); diff --git a/ntex/tests/http_openssl.rs b/ntex/tests/http_openssl.rs index c91de0b8..921310a8 100644 --- a/ntex/tests/http_openssl.rs +++ b/ntex/tests/http_openssl.rs @@ -1,5 +1,6 @@ #![cfg(feature = "openssl")] -use std::{io, sync::atomic::AtomicUsize, sync::atomic::Ordering, sync::Arc}; +use std::io; +use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc, Mutex}; use futures_util::stream::{once, Stream, StreamExt}; use tls_openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; @@ -424,11 +425,12 @@ async fn test_h2_service_error() { assert_eq!(bytes, Bytes::from_static(b"error")); } -struct SetOnDrop(Arc); +struct SetOnDrop(Arc, Arc>>>); impl Drop for SetOnDrop { fn drop(&mut self) { self.0.fetch_add(1, Ordering::Relaxed); + let _ = self.1.lock().unwrap().take().unwrap().send(()); } } @@ -436,17 +438,20 @@ impl Drop for SetOnDrop { async fn test_h2_client_drop() -> io::Result<()> { let count = Arc::new(AtomicUsize::new(0)); let count2 = count.clone(); + let (tx, rx) = ::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); let srv = test_server(move || { + let tx = tx.clone(); let count = count2.clone(); HttpService::build() .h2(move |req: Request| { - let count = count.clone(); + let st = SetOnDrop(count.clone(), tx.clone()); async move { - let _st = SetOnDrop(count); assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_2); - sleep(Seconds(100)).await; + sleep(Seconds(30)).await; + drop(st); Ok::<_, io::Error>(Response::Ok().finish()) } }) @@ -454,9 +459,9 @@ async fn test_h2_client_drop() -> io::Result<()> { .map_err(|_| ()) }); - let result = timeout(Millis(250), srv.srequest(Method::GET, "/").send()).await; + let result = timeout(Millis(1500), srv.srequest(Method::GET, "/").send()).await; assert!(result.is_err()); - sleep(Millis(150)).await; + let _ = timeout(Millis(1500), rx).await; assert_eq!(count.load(Ordering::Relaxed), 1); Ok(()) } @@ -539,13 +544,19 @@ async fn test_ws_transport() { async fn test_h2_graceful_shutdown() -> io::Result<()> { let count = Arc::new(AtomicUsize::new(0)); let count2 = count.clone(); + let (tx, rx) = ::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); let srv = test_server(move || { + let tx = tx.clone(); let count = count2.clone(); HttpService::build() .h2(move |_| { let count = count.clone(); count.fetch_add(1, Ordering::Relaxed); + if count.load(Ordering::Relaxed) == 2 { + let _ = tx.lock().unwrap().take().unwrap().send(()); + } async move { sleep(Millis(1000)).await; count.fetch_sub(1, Ordering::Relaxed); @@ -566,7 +577,7 @@ async fn test_h2_graceful_shutdown() -> io::Result<()> { let _ = req.send().await.unwrap(); sleep(Millis(100000)).await; }); - sleep(Millis(150)).await; + let _ = rx.await; assert_eq!(count.load(Ordering::Relaxed), 2); let (tx, rx) = oneshot::channel(); @@ -574,8 +585,6 @@ async fn test_h2_graceful_shutdown() -> io::Result<()> { srv.stop().await; let _ = tx.send(()); }); - sleep(Millis(150)).await; - assert_eq!(count.load(Ordering::Relaxed), 2); let _ = rx.await; assert_eq!(count.load(Ordering::Relaxed), 0); diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 44512500..0227573b 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -1,4 +1,4 @@ -use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc}; +use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc, Mutex}; use std::{io, io::Read, io::Write, net}; use futures_util::future::{self, FutureExt}; @@ -405,6 +405,36 @@ async fn test_http1_handle_not_consumed_payload() { assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); } +/// Handle payload errors (keep-alive, disconnects) +#[ntex::test] +async fn test_http1_handle_payload_errors() { + let count = Arc::new(AtomicUsize::new(0)); + let count2 = count.clone(); + + let srv = test_server(move || { + let count = count2.clone(); + HttpService::build().h1(move |mut req: Request| { + let count = count.clone(); + async move { + let mut pl = req.take_payload(); + let result = pl.recv().await; + if result.unwrap().is_err() { + count.fetch_add(1, Ordering::Relaxed); + } + Ok::<_, io::Error>(Response::Ok().finish()) + } + }) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = + stream.write_all(b"GET /test/tests/test HTTP/1.1\r\ncontent-length: 99999\r\n\r\n"); + sleep(Millis(250)).await; + drop(stream); + sleep(Millis(250)).await; + assert_eq!(count.load(Ordering::Acquire), 1); +} + #[ntex::test] async fn test_content_length() { let srv = test_server(|| { @@ -693,11 +723,12 @@ async fn test_h1_service_error() { assert_eq!(bytes, Bytes::from_static(b"error")); } -struct SetOnDrop(Arc); +struct SetOnDrop(Arc, Option<::oneshot::Sender<()>>); impl Drop for SetOnDrop { fn drop(&mut self) { self.0.fetch_add(1, Ordering::Relaxed); + let _ = self.1.take().unwrap().send(()); } } @@ -705,24 +736,28 @@ impl Drop for SetOnDrop { async fn test_h1_client_drop() -> io::Result<()> { let count = Arc::new(AtomicUsize::new(0)); let count2 = count.clone(); + let (tx, rx) = ::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); let srv = test_server(move || { + let tx = tx.clone(); let count = count2.clone(); HttpService::build().h1(move |req: Request| { + let tx = tx.clone(); let count = count.clone(); async move { - let _st = SetOnDrop(count); + let _st = SetOnDrop(count, tx.lock().unwrap().take()); assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_11); - sleep(Seconds(100)).await; + sleep(Millis(50000)).await; Ok::<_, io::Error>(Response::Ok().finish()) } }) }); - let result = timeout(Millis(100), srv.request(Method::GET, "/").send()).await; + let result = timeout(Millis(1500), srv.request(Method::GET, "/").send()).await; assert!(result.is_err()); - sleep(Millis(250)).await; + let _ = rx.await; assert_eq!(count.load(Ordering::Relaxed), 1); Ok(()) } @@ -731,12 +766,18 @@ async fn test_h1_client_drop() -> io::Result<()> { async fn test_h1_gracefull_shutdown() { let count = Arc::new(AtomicUsize::new(0)); let count2 = count.clone(); + let (tx, rx) = ::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); let srv = test_server(move || { + let tx = tx.clone(); let count = count2.clone(); HttpService::build().h1(move |_: Request| { let count = count.clone(); count.fetch_add(1, Ordering::Relaxed); + if count.load(Ordering::Relaxed) == 2 { + let _ = tx.lock().unwrap().take().unwrap().send(()); + } async move { sleep(Millis(1000)).await; count.fetch_sub(1, Ordering::Relaxed); @@ -751,7 +792,7 @@ async fn test_h1_gracefull_shutdown() { let mut stream2 = net::TcpStream::connect(srv.addr()).unwrap(); let _ = stream2.write_all(b"GET /index.html HTTP/1.1\r\n\r\n"); - sleep(Millis(150)).await; + let _ = rx.await; assert_eq!(count.load(Ordering::Relaxed), 2); let (tx, rx) = oneshot::channel(); @@ -759,8 +800,6 @@ async fn test_h1_gracefull_shutdown() { srv.stop().await; let _ = tx.send(()); }); - sleep(Millis(150)).await; - assert_eq!(count.load(Ordering::Relaxed), 2); let _ = rx.await; assert_eq!(count.load(Ordering::Relaxed), 0); @@ -770,12 +809,18 @@ async fn test_h1_gracefull_shutdown() { async fn test_h1_gracefull_shutdown_2() { let count = Arc::new(AtomicUsize::new(0)); let count2 = count.clone(); + let (tx, rx) = ::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); let srv = test_server(move || { + let tx = tx.clone(); let count = count2.clone(); HttpService::build().finish(move |_: Request| { let count = count.clone(); count.fetch_add(1, Ordering::Relaxed); + if count.load(Ordering::Relaxed) == 2 { + let _ = tx.lock().unwrap().take().unwrap().send(()); + } async move { sleep(Millis(1000)).await; count.fetch_sub(1, Ordering::Relaxed); @@ -790,17 +835,14 @@ async fn test_h1_gracefull_shutdown_2() { let mut stream2 = net::TcpStream::connect(srv.addr()).unwrap(); let _ = stream2.write_all(b"GET /index.html HTTP/1.1\r\n\r\n"); - sleep(Millis(150)).await; - assert_eq!(count.load(Ordering::Relaxed), 2); + let _ = rx.await; + assert_eq!(count.load(Ordering::Acquire), 2); let (tx, rx) = oneshot::channel(); rt::spawn(async move { srv.stop().await; let _ = tx.send(()); }); - sleep(Millis(150)).await; - assert_eq!(count.load(Ordering::Relaxed), 2); - let _ = rx.await; assert_eq!(count.load(Ordering::Relaxed), 0); } diff --git a/ntex/tests/server.rs b/ntex/tests/server.rs index 09432a99..8e97908d 100644 --- a/ntex/tests/server.rs +++ b/ntex/tests/server.rs @@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; #[cfg(feature = "tokio")] use std::{io, sync::Arc}; -use std::{io::Read, net, sync::mpsc, thread, time}; +use std::{io::Read, io::Write, net, sync::mpsc, thread, time}; use ntex::codec::BytesCodec; use ntex::io::Io; @@ -71,6 +71,7 @@ async fn test_listen() { #[ntex::test] #[cfg(unix)] +#[allow(clippy::unused_io_amount)] async fn test_run() { let addr = TestServer::unused_addr(); let (tx, rx) = mpsc::channel(); @@ -80,6 +81,7 @@ async fn test_run() { sys.run(move || { let srv = build() .backlog(100) + .workers(1) .disable_signals() .bind("test", addr, move |_| { fn_service(|io: Io| async move { @@ -90,6 +92,7 @@ async fn test_run() { }) }) .unwrap() + .set_tag("test", "SRV") .run(); let _ = tx.send((srv, ntex::rt::System::current())); Ok(()) @@ -99,6 +102,7 @@ async fn test_run() { let mut buf = [1u8; 4]; let mut conn = net::TcpStream::connect(addr).unwrap(); + conn.write(&b"test"[..]).unwrap(); let _ = conn.read_exact(&mut buf); assert_eq!(buf, b"test"[..]); diff --git a/ntex/tests/web_server.rs b/ntex/tests/web_server.rs index cbb7956d..a1ab4ace 100644 --- a/ntex/tests/web_server.rs +++ b/ntex/tests/web_server.rs @@ -1,6 +1,5 @@ use std::{future::Future, io, io::Read, io::Write, pin::Pin, task::Context, task::Poll}; -use brotli2::write::{BrotliDecoder, BrotliEncoder}; use flate2::read::GzDecoder; use flate2::write::{GzEncoder, ZlibDecoder, ZlibEncoder}; use flate2::Compression; @@ -318,36 +317,6 @@ async fn test_body_chunked_implicit() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[ntex::test] -async fn test_body_br_streaming() { - let srv = test::server_with(test::config().h1(), || { - App::new().wrap(Compress::new(ContentEncoding::Br)).service( - web::resource("/").route(web::to(move || async { - HttpResponse::Ok() - .streaming(TestBody::new(Bytes::from_static(STR.as_ref()), 24)) - })), - ) - }); - - let mut response = srv - .get("/") - .header(ACCEPT_ENCODING, "br") - .no_decompress() - .send() - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - - // decode br - let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); -} - #[ntex::test] async fn test_head_binary() { let srv = test::server_with(test::config().h1(), || { @@ -422,35 +391,6 @@ async fn test_body_deflate() { assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); } -#[ntex::test] -async fn test_body_brotli() { - let srv = test::server_with(test::config().h1(), || { - App::new().wrap(Compress::new(ContentEncoding::Br)).service( - web::resource("/") - .route(web::to(move || async { HttpResponse::Ok().body(STR) })), - ) - }); - - // client request - let mut response = srv - .get("/") - .header(ACCEPT_ENCODING, "br") - .no_decompress() - .send() - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - - // decode brotli - let mut e = BrotliDecoder::new(Vec::with_capacity(2048)); - e.write_all(bytes.as_ref()).unwrap(); - let dec = e.finish().unwrap(); - assert_eq!(Bytes::from(dec), Bytes::from_static(STR.as_ref())); -} - #[ntex::test] async fn test_encoding() { let srv = test::server_with(test::config().h1(), || { @@ -644,204 +584,6 @@ async fn test_reading_deflate_encoding_large_random() { assert_eq!(bytes, Bytes::from(data)); } -#[ntex::test] -async fn test_brotli_encoding() { - let srv = test::server_with(test::config().h1(), || { - App::new().service(web::resource("/").route(web::to(move |body: Bytes| async { - HttpResponse::Ok().body(body) - }))) - }); - - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(STR.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request - let request = srv - .post("/") - .header(CONTENT_ENCODING, "br") - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from_static(STR.as_ref())); -} - -#[ntex::test] -async fn test_brotli_encoding_large() { - let data = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(320_000) - .map(char::from) - .collect::(); - - let srv = test::server_with(test::config().h1(), || { - App::new().service( - web::resource("/") - .state(web::types::PayloadConfig::new(320_000)) - .route(web::to(move |body: Bytes| async { - HttpResponse::Ok().streaming(TestBody::new(body, 10240)) - })), - ) - }); - - let mut e = BrotliEncoder::new(Vec::new(), 5); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request - let request = srv - .post("/") - .header(CONTENT_ENCODING, "br") - .send_body(enc.clone()); - let mut response = request.await.unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().limit(320_000).await.unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - -#[cfg(feature = "openssl")] -#[ntex::test] -async fn test_brotli_encoding_large_openssl() { - // load ssl keys - use tls_openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; - - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("./tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("./tests/cert.pem") - .unwrap(); - - let data = STR.repeat(10); - let srv = test::server_with(test::config().openssl(builder.build()), move || { - App::new().service(web::resource("/").route(web::to(|bytes: Bytes| async { - HttpResponse::Ok() - .encoding(ContentEncoding::Identity) - .body(bytes) - }))) - }); - - // body - let mut e = BrotliEncoder::new(Vec::new(), 3); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request - let mut response = srv - .post("/") - .header(CONTENT_ENCODING, "br") - .send_body(enc) - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - -#[cfg(feature = "openssl")] -#[ntex::test] -async fn test_brotli_encoding_large_openssl_h1() { - // load ssl keys - use tls_openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; - - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("./tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("./tests/cert.pem") - .unwrap(); - - let data = STR.repeat(10); - let srv = test::server_with(test::config().openssl(builder.build()).h1(), move || { - App::new().service(web::resource("/").route(web::to(|bytes: Bytes| async { - HttpResponse::Ok() - .encoding(ContentEncoding::Identity) - .body(bytes) - }))) - }); - - // body - let mut e = BrotliEncoder::new(Vec::new(), 3); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request - let mut response = srv - .post("/") - .header(CONTENT_ENCODING, "br") - .send_body(enc) - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - -#[cfg(feature = "openssl")] -#[ntex::test] -async fn test_brotli_encoding_large_openssl_h2() { - // load ssl keys - use tls_openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; - - let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - builder - .set_private_key_file("./tests/key.pem", SslFiletype::PEM) - .unwrap(); - builder - .set_certificate_chain_file("./tests/cert.pem") - .unwrap(); - builder.set_alpn_select_callback(|_, protos| { - const H2: &[u8] = b"\x02h2"; - const H11: &[u8] = b"\x08http/1.1"; - if protos.windows(3).any(|window| window == H2) { - Ok(b"h2") - } else if protos.windows(9).any(|window| window == H11) { - Ok(b"http/1.1") - } else { - Err(AlpnError::NOACK) - } - }); - builder.set_alpn_protos(b"\x08http/1.1\x02h2").unwrap(); - - let data = STR.repeat(10); - let srv = test::server_with(test::config().openssl(builder.build()).h2(), move || { - App::new().service(web::resource("/").route(web::to(|bytes: Bytes| async { - HttpResponse::Ok() - .encoding(ContentEncoding::Identity) - .body(bytes) - }))) - }); - - // body - let mut e = BrotliEncoder::new(Vec::new(), 3); - e.write_all(data.as_ref()).unwrap(); - let enc = e.finish().unwrap(); - - // client request - let mut response = srv - .post("/") - .header(CONTENT_ENCODING, "br") - .send_body(enc) - .await - .unwrap(); - assert!(response.status().is_success()); - - // read response - let bytes = response.body().await.unwrap(); - assert_eq!(bytes, Bytes::from(data)); -} - #[cfg(all(feature = "rustls", feature = "openssl"))] #[ntex::test] async fn test_reading_deflate_encoding_large_random_rustls() { @@ -868,7 +610,7 @@ async fn test_reading_deflate_encoding_large_random_rustls() { // client request let req = srv .post("/") - .timeout(Millis(10_000)) + .timeout(Millis(30_000)) .header(CONTENT_ENCODING, "deflate") .send_stream(TestBody::new(Bytes::from(enc), 1024)); @@ -909,7 +651,7 @@ async fn test_reading_deflate_encoding_large_random_rustls_h1() { // client request let req = srv .post("/") - .timeout(Millis(10_000)) + .timeout(Millis(30_000)) .header(CONTENT_ENCODING, "deflate") .send_stream(TestBody::new(Bytes::from(enc), 1024)); @@ -950,7 +692,7 @@ async fn test_reading_deflate_encoding_large_random_rustls_h2() { // client request let req = srv .post("/") - .timeout(Millis(10_000)) + .timeout(Millis(30_000)) .header(CONTENT_ENCODING, "deflate") .send_stream(TestBody::new(Bytes::from(enc), 1024)); diff --git a/ntex/tests/web_ws.rs b/ntex/tests/web_ws.rs index 2187321d..20e1d324 100644 --- a/ntex/tests/web_ws.rs +++ b/ntex/tests/web_ws.rs @@ -21,6 +21,8 @@ async fn service(msg: ws::Frame) -> Result, io::Error> { #[ntex::test] async fn web_ws() { + let _ = env_logger::try_init(); + let srv = test::server(|| { App::new().service(web::resource("/").route(web::to( |req: HttpRequest| async move {