From 12d108c8c25c89645dfa8f66b556e410de5b92a6 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 31 Mar 2025 19:40:10 +0500 Subject: [PATCH] Refactor polling impl --- Cargo.toml | 3 + ntex-io/Cargo.toml | 2 +- ntex-io/src/tasks.rs | 12 +- ntex-net/Cargo.toml | 3 +- ntex-net/src/rt_polling/connect.rs | 7 +- ntex-net/src/rt_polling/driver.rs | 224 +++++++++++++++++++++-------- ntex-net/src/rt_polling/mod.rs | 58 ++++++++ ntex/tests/connect.rs | 3 +- ntex/tests/http_awc_client.rs | 9 +- 9 files changed, 238 insertions(+), 83 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 681247bd..e5006289 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,10 @@ ntex-compio = { path = "ntex-compio" } ntex-tokio = { path = "ntex-tokio" } ntex-neon = { git = "https://github.com/ntex-rs/neon.git" } +polling = { git = "https://github.com/fafhrd91/polling.git" } + #ntex-neon = { path = "../dev/neon" } +#polling = { path = "../dev/polling" } [workspace.dependencies] async-channel = "2" diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index f55aa5d0..f7a54de3 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "2.11.1" +version = "2.11.2" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 55f99416..3a078c18 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -739,19 +739,11 @@ impl IoContext { pub fn with_read_buf(&self, f: F) -> Poll<()> where - F: FnOnce(&mut BytesVec) -> Poll>, + F: FnOnce(&mut BytesVec, usize, usize) -> Poll>, { let inner = &self.0 .0; let (hw, lw) = self.0.memory_pool().read_params().unpack(); - let result = inner.buffer.with_read_source(&self.0, |buf| { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - f(buf) - }); + let result = inner.buffer.with_read_source(&self.0, |buf| f(buf, hw, lw)); // handle buffer changes match result { diff --git a/ntex-net/Cargo.toml b/ntex-net/Cargo.toml index 7c2ae447..12dec037 100644 --- a/ntex-net/Cargo.toml +++ b/ntex-net/Cargo.toml @@ -34,7 +34,7 @@ io-uring = ["ntex-neon/io-uring", "dep:io-uring", "socket2"] ntex-service = "3.3" ntex-bytes = "0.1" ntex-http = "0.1" -ntex-io = "2.11.1" +ntex-io = "2.11.2" ntex-rt = "0.4.25" ntex-util = "2.5" @@ -57,3 +57,4 @@ polling = { workspace = true, optional = true } [dev-dependencies] ntex = "2" +oneshot = "0.1" diff --git a/ntex-net/src/rt_polling/connect.rs b/ntex-net/src/rt_polling/connect.rs index a0e2bd83..3123fd16 100644 --- a/ntex-net/src/rt_polling/connect.rs +++ b/ntex-net/src/rt_polling/connect.rs @@ -1,7 +1,7 @@ use std::os::fd::{AsRawFd, RawFd}; use std::{cell::RefCell, io, rc::Rc, task::Poll}; -use ntex_neon::driver::{DriverApi, Event, Handler}; +use ntex_neon::driver::{DriverApi, Event, Handler, PollMode}; use ntex_neon::{syscall, Runtime}; use ntex_util::channel::oneshot::Sender; use slab::Slab; @@ -62,7 +62,9 @@ impl ConnectOps { let item = Item { fd, sender }; let id = self.0.connects.borrow_mut().insert(item); - self.0.api.attach(fd, id as u32, Some(Event::writable(0))); + self.0 + .api + .attach(fd, id as u32, Event::writable(0), PollMode::Oneshot); Ok(id) } } @@ -72,7 +74,6 @@ impl Handler for ConnectOpsBatcher { log::debug!("connect-fd is readable {:?}", id); let mut connects = self.inner.connects.borrow_mut(); - if connects.contains(id) { let item = connects.remove(id); if event.writable { diff --git a/ntex-net/src/rt_polling/driver.rs b/ntex-net/src/rt_polling/driver.rs index 38263fc6..6fd8d2e0 100644 --- a/ntex-net/src/rt_polling/driver.rs +++ b/ntex-net/src/rt_polling/driver.rs @@ -1,11 +1,11 @@ use std::os::fd::{AsRawFd, RawFd}; -use std::{cell::Cell, cell::RefCell, future::Future, io, mem, rc::Rc, task, task::Poll}; +use std::{cell::Cell, cell::RefCell, future::Future, io, mem, rc::Rc, task::Poll}; -use ntex_neon::driver::{DriverApi, Event, Handler}; +use ntex_neon::driver::{DriverApi, Event, Handler, PollMode}; use ntex_neon::{syscall, Runtime}; use slab::Slab; -use ntex_bytes::BufMut; +use ntex_bytes::{BufMut, BytesVec}; use ntex_io::IoContext; pub(crate) struct StreamCtl { @@ -16,15 +16,17 @@ pub(crate) struct StreamCtl { bitflags::bitflags! { #[derive(Copy, Clone, Debug)] struct Flags: u8 { - const RD = 0b0000_0001; - const WR = 0b0000_0010; + const RD = 0b0000_0001; + const WR = 0b0000_0010; + const ERR = 0b0000_0100; + const RDSH = 0b0000_1000; } } struct StreamItem { io: Option, fd: RawFd, - flags: Flags, + flags: Cell, ref_count: u16, context: IoContext, } @@ -46,6 +48,22 @@ impl StreamItem { fn tag(&self) -> &'static str { self.context.tag() } + + fn contains(&self, flag: Flags) -> bool { + self.flags.get().contains(flag) + } + + fn insert(&self, fl: Flags) { + let mut flags = self.flags.get(); + flags.insert(fl); + self.flags.set(flags); + } + + fn remove(&self, fl: Flags) { + let mut flags = self.flags.get(); + flags.remove(fl); + self.flags.set(flags); + } } impl StreamOps { @@ -75,7 +93,7 @@ impl StreamOps { context, io: Some(io), ref_count: 1, - flags: Flags::empty(), + flags: Cell::new(Flags::empty()), }; StreamCtl { id: streams.insert(item) as u32, @@ -86,7 +104,8 @@ impl StreamOps { self.0.api.attach( fd, stream.id, - Some(Event::new(0, false, false).with_interrupt()), + Event::new(0, false, false).with_interrupt(), + PollMode::Edge, ); stream } @@ -110,38 +129,38 @@ impl Handler for StreamOpsHandler { } log::debug!("{}: FD event {:?} event: {:?}", item.tag(), id, ev); - // handle HUP - if ev.is_interrupt() { - item.context.stopped(None); - close(id as u32, item, &self.inner.api, None, true); - return; - } - + let mut changed = false; let mut renew_ev = Event::new(0, false, false).with_interrupt(); + // handle read op if ev.readable { - let res = item.context.with_read_buf(|buf| { - let chunk = buf.chunk_mut(); - let result = task::ready!(syscall!( - break libc::read(item.fd, chunk.as_mut_ptr() as _, chunk.len()) - )); - if let Ok(size) = result { - log::debug!("{}: data {:?}, s: {:?}", item.tag(), item.fd, size); - unsafe { buf.advance_mut(size) }; - } - Poll::Ready(result) - }); + let res = item + .context + .with_read_buf(|buf, hw, lw| read(item, buf, hw, lw)); if res.is_pending() && item.context.is_read_ready() { renew_ev.readable = true; - item.flags.insert(Flags::RD); } else { - item.flags.remove(Flags::RD); + changed = true; + item.remove(Flags::RD); } - } else if item.flags.contains(Flags::RD) { + } else if item.contains(Flags::RD) { renew_ev.readable = true; } + // handle error + if ev.is_err() == Some(true) { + item.insert(Flags::ERR); + } + + // handle HUP + if ev.is_interrupt() { + item.context.stopped(None); + close(id as u32, item, &self.inner.api, None); + return; + } + + // handle write op if ev.writable { let result = item.context.with_write_buf(|buf| { log::debug!("{}: write {:?} s: {:?}", item.tag(), item.fd, buf.len()); @@ -149,15 +168,19 @@ impl Handler for StreamOpsHandler { }); if result.is_pending() { renew_ev.writable = true; - item.flags.insert(Flags::WR); } else { - item.flags.remove(Flags::WR); + changed = true; + item.remove(Flags::WR); } - } else if item.flags.contains(Flags::WR) { + } else if item.contains(Flags::WR) { renew_ev.writable = true; } - self.inner.api.modify(item.fd, id as u32, renew_ev); + if changed { + self.inner + .api + .modify(item.fd, id as u32, renew_ev, PollMode::Edge); + } // delayed drops if self.inner.delayd_drop.get() { @@ -173,7 +196,7 @@ impl Handler for StreamOpsHandler { item.fd, item.io.is_some() ); - close(id, &mut item, &self.inner.api, None, true); + close(id, &mut item, &self.inner.api, None); } } self.inner.delayd_drop.set(false); @@ -191,7 +214,7 @@ impl Handler for StreamOpsHandler { item.fd, err ); - close(id as u32, item, &self.inner.api, Some(err), false); + close(id as u32, item, &self.inner.api, Some(err)); } }) } @@ -209,19 +232,96 @@ impl StreamOpsInner { } } +fn read( + item: &StreamItem, + buf: &mut BytesVec, + hw: usize, + lw: usize, +) -> Poll> { + log::debug!( + "{}: reading fd ({:?}) flags: {:?}", + item.tag(), + item.fd, + item.context.flags() + ); + if item.contains(Flags::RDSH) { + return Poll::Ready(Ok(0)); + } + + let mut total = 0; + loop { + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + + let chunk = buf.chunk_mut(); + let chunk_len = chunk.len(); + + let result = + syscall!(break libc::read(item.fd, chunk.as_mut_ptr() as _, chunk.len())); + if let Poll::Ready(Ok(size)) = result { + unsafe { buf.advance_mut(size) }; + total += size; + //if size != 0 { + if size == chunk_len { + continue; + } + } + + log::debug!( + "{}: read fd ({:?}), s: {:?}, cap: {:?}, result: {:?}", + item.tag(), + item.fd, + total, + buf.remaining_mut(), + result + ); + + return match result { + Poll::Ready(Err(err)) => { + if total > 0 { + item.insert(Flags::ERR); + item.context.stopped(Some(err)); + Poll::Ready(Ok(total)) + } else { + Poll::Ready(Err(err)) + } + } + Poll::Ready(Ok(size)) => { + if size == 0 { + item.insert(Flags::RDSH); + item.context.stopped(None); + } + Poll::Ready(Ok(total)) + } + Poll::Pending => { + if total > 0 { + Poll::Ready(Ok(total)) + } else { + Poll::Pending + } + } + }; + } +} + fn close( id: u32, item: &mut StreamItem, api: &DriverApi, error: Option, - shutdown: bool, ) -> Option>> { if let Some(io) = item.io.take() { log::debug!("{}: Closing ({}), {:?}", item.tag(), id, item.fd); mem::forget(io); - if let Some(err) = error { + let shutdown = if let Some(err) = error { item.context.stopped(Some(err)); - } + false + } else { + !item.flags.get().intersects(Flags::ERR | Flags::RDSH) + }; let fd = item.fd; api.detach(fd, id); Some(ntex_rt::spawn_blocking(move || { @@ -240,7 +340,7 @@ impl StreamCtl { let id = self.id as usize; let fut = self.inner.with(|streams| { let item = &mut streams[id]; - close(self.id, item, &self.inner.api, None, false) + close(self.id, item, &self.inner.api, None) }); async move { if let Some(fut) = fut { @@ -263,48 +363,41 @@ impl StreamCtl { pub(crate) fn modify(&self, rd: bool, wr: bool) { self.inner.with(|streams| { let item = &mut streams[self.id as usize]; + if item.contains(Flags::ERR) { + return; + } log::debug!( - "{}: Modify interest ({}), {:?} rd: {:?}, wr: {:?}", + "{}: Modify interest ({}), {:?} rd: {:?}, wr: {:?}, flags: {:?}", item.tag(), self.id, item.fd, rd, - wr + wr, + item.flags ); + let mut changed = false; let mut event = Event::new(0, false, false).with_interrupt(); if rd { - if item.flags.contains(Flags::RD) { + if item.contains(Flags::RD) { event.readable = true; } else { - let res = item.context.with_read_buf(|buf| { - let chunk = buf.chunk_mut(); - let result = task::ready!(syscall!( - break libc::read(item.fd, chunk.as_mut_ptr() as _, chunk.len()) - )); - if let Ok(size) = result { - log::debug!( - "{}: read {:?}, s: {:?}", - item.tag(), - item.fd, - size - ); - unsafe { buf.advance_mut(size) }; - } - Poll::Ready(result) - }); + let res = item + .context + .with_read_buf(|buf, hw, lw| read(item, buf, hw, lw)); if res.is_pending() && item.context.is_read_ready() { + changed = true; event.readable = true; - item.flags.insert(Flags::RD); + item.insert(Flags::RD); } } } if wr { - if item.flags.contains(Flags::WR) { + if item.contains(Flags::WR) { event.writable = true; } else { let result = item.context.with_write_buf(|buf| { @@ -320,13 +413,18 @@ impl StreamCtl { }); if result.is_pending() { + changed = true; event.writable = true; - item.flags.insert(Flags::WR); + item.insert(Flags::WR); } } } - self.inner.api.modify(item.fd, self.id, event); + if changed { + self.inner + .api + .modify(item.fd, self.id, event, PollMode::Edge); + } }) } } @@ -357,7 +455,7 @@ impl Drop for StreamCtl { item.fd, item.io.is_some() ); - close(self.id, &mut item, &self.inner.api, None, true); + close(self.id, &mut item, &self.inner.api, None); } self.inner.streams.set(Some(streams)); } else { diff --git a/ntex-net/src/rt_polling/mod.rs b/ntex-net/src/rt_polling/mod.rs index b4fb928b..755fda0a 100644 --- a/ntex-net/src/rt_polling/mod.rs +++ b/ntex-net/src/rt_polling/mod.rs @@ -67,3 +67,61 @@ pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { Socket::from(stream), )?))) } + +#[cfg(test)] +mod tests { + use ntex::{io::Io, time::sleep, time::Millis, util::PoolId}; + use std::sync::{Arc, Mutex}; + + use crate::connect::Connect; + + const DATA: &[u8] = b"Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World \ + Hello World Hello World Hello World Hello World Hello World"; + + // #[ntex::test] + async fn idle_disconnect() { + PoolId::P5.set_read_params(24, 12); + let (tx, rx) = ::oneshot::channel(); + let tx = Arc::new(Mutex::new(Some(tx))); + + let server = ntex::server::test_server(move || { + let tx = tx.clone(); + ntex_service::fn_service(move |io: Io<_>| { + tx.lock().unwrap().take().unwrap().send(()).unwrap(); + + async move { + io.write(DATA).unwrap(); + sleep(Millis(250)).await; + io.close(); + Ok::<_, ()>(()) + } + }) + }); + + let msg = Connect::new(server.addr()); + let io = crate::connect::connect(msg).await.unwrap(); + io.set_memory_pool(PoolId::P5.into()); + rx.await.unwrap(); + + io.on_disconnect().await; + } +} diff --git a/ntex/tests/connect.rs b/ntex/tests/connect.rs index 5ecd51b7..523232a8 100644 --- a/ntex/tests/connect.rs +++ b/ntex/tests/connect.rs @@ -1,9 +1,8 @@ use std::{io, rc::Rc}; -use ntex::codec::BytesCodec; -use ntex::connect::Connect; use ntex::io::{types::PeerAddr, Io}; use ntex::service::{chain_factory, fn_service, Pipeline, ServiceFactory}; +use ntex::{codec::BytesCodec, connect::Connect}; use ntex::{server::build_test_server, server::test_server, time, util::Bytes}; #[cfg(feature = "rustls")] diff --git a/ntex/tests/http_awc_client.rs b/ntex/tests/http_awc_client.rs index bd4c7e0a..e80644e9 100644 --- a/ntex/tests/http_awc_client.rs +++ b/ntex/tests/http_awc_client.rs @@ -682,15 +682,18 @@ async fn client_read_until_eof() { for stream in lst.incoming() { if let Ok(mut stream) = stream { let mut b = [0; 1000]; - let _ = stream.read(&mut b).unwrap(); - let _ = stream + log::debug!("Reading request"); + let res = stream.read(&mut b).unwrap(); + log::debug!("Read {:?}", res); + let res = stream .write_all(b"HTTP/1.0 200 OK\r\nconnection: close\r\n\r\nwelcome!"); + log::debug!("Sent {:?}", res); } else { break; } } }); - sleep(Millis(300)).await; + sleep(Millis(500)).await; // client request let req = Client::build()