diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index 498e249d..a99d0411 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -98,17 +98,19 @@ impl IoState { } pub(super) fn io_stopped(&self, err: Option) { - if err.is_some() { - self.error.set(err); + if !self.flags.get().contains(Flags::IO_STOPPED) { + if err.is_some() { + self.error.set(err); + } + self.read_task.wake(); + self.write_task.wake(); + self.dispatch_task.wake(); + self.notify_disconnect(); + self.handle.take(); + self.insert_flags( + Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS, + ); } - self.read_task.wake(); - self.write_task.wake(); - self.dispatch_task.wake(); - self.notify_disconnect(); - self.handle.take(); - self.insert_flags( - Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS, - ); } /// Gracefully shutdown read and write io tasks diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 31681a59..883ac7ee 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -537,9 +537,7 @@ impl IoContext { self.0.tag(), nbytes ); - if !inner.dispatch_task.wake_checked() { - log::error!("Dispatcher waker is not registered"); - } + inner.dispatch_task.wake(); } else { if nbytes >= hw { // read task is paused because of read back-pressure @@ -779,11 +777,7 @@ impl IoContext { self.0.tag(), nbytes ); - if !inner.dispatch_task.wake_checked() { - log::error!( - "{}: Dispatcher waker is not registered, bytes: {:?}, flags: {:?}", - self.0.tag(), status.nbytes, self.flags()); - } + inner.dispatch_task.wake(); } else { if nbytes >= hw { // read task is paused because of read back-pressure diff --git a/ntex-net/src/lib.rs b/ntex-net/src/lib.rs index f97cb50c..9d2a4387 100644 --- a/ntex-net/src/lib.rs +++ b/ntex-net/src/lib.rs @@ -14,7 +14,7 @@ cfg_if::cfg_if! { mod rt_impl; pub use self::rt_impl::{ from_tcp_stream, from_unix_stream, tcp_connect, tcp_connect_in, unix_connect, - unix_connect_in, + unix_connect_in, active_stream_ops }; } else if #[cfg(all(unix, feature = "neon"))] { #[path = "rt_polling/mod.rs"] diff --git a/ntex-net/src/rt_polling/driver.rs b/ntex-net/src/rt_polling/driver.rs index 88f95894..22c04f50 100644 --- a/ntex-net/src/rt_polling/driver.rs +++ b/ntex-net/src/rt_polling/driver.rs @@ -1,72 +1,50 @@ -use std::os::fd::{AsRawFd, RawFd}; -use std::{cell::Cell, cell::RefCell, future::Future, io, mem, rc::Rc, task::Poll}; +use std::os::fd::RawFd; +use std::{cell::Cell, cell::RefCell, future::Future, io, rc::Rc, task::Poll}; use ntex_neon::driver::{DriverApi, Event, Handler, PollMode}; use ntex_neon::{syscall, Runtime}; use slab::Slab; -use ntex_bytes::{BufMut, BytesVec}; +use ntex_bytes::BufMut; use ntex_io::IoContext; -pub(crate) struct StreamCtl { +pub(crate) struct StreamCtl { id: u32, - inner: Rc>, + inner: Rc, } bitflags::bitflags! { #[derive(Copy, Clone, Debug)] struct Flags: u8 { - const RD = 0b0000_0001; - const WR = 0b0000_0010; - const ERR = 0b0000_0100; - const RDSH = 0b0000_1000; + const RD = 0b0000_0001; + const WR = 0b0000_0010; + const RDSH = 0b0000_0100; + const FAILED = 0b0000_1000; + const CLOSED = 0b0001_0000; } } -struct StreamItem { - io: Option, +struct StreamItem { fd: RawFd, - flags: Cell, + flags: Flags, ref_count: u16, context: IoContext, } -pub(crate) struct StreamOps(Rc>); +pub(crate) struct StreamOps(Rc); -struct StreamOpsHandler { - inner: Rc>, +struct StreamOpsHandler { + inner: Rc, } -struct StreamOpsInner { +struct StreamOpsInner { api: DriverApi, delayd_drop: Cell, feed: RefCell>, - streams: Cell>>>>, + streams: Cell>>>, } -impl StreamItem { - fn tag(&self) -> &'static str { - self.context.tag() - } - - fn contains(&self, flag: Flags) -> bool { - self.flags.get().contains(flag) - } - - fn insert(&self, fl: Flags) { - let mut flags = self.flags.get(); - flags.insert(fl); - self.flags.set(flags); - } - - fn remove(&self, fl: Flags) { - let mut flags = self.flags.get(); - flags.remove(fl); - self.flags.set(flags); - } -} - -impl StreamOps { +impl StreamOps { pub(crate) fn current() -> Self { Runtime::value(|rt| { let mut inner = None; @@ -89,15 +67,13 @@ impl StreamOps { Self::current().0.with(|streams| streams.len()) } - pub(crate) fn register(&self, io: T, context: IoContext) -> StreamCtl { - let fd = io.as_raw_fd(); + pub(crate) fn register(&self, fd: RawFd, context: IoContext) -> StreamCtl { let stream = self.0.with(move |streams| { let item = StreamItem { fd, context, - io: Some(io), ref_count: 1, - flags: Cell::new(Flags::empty()), + flags: Flags::empty(), }; StreamCtl { id: streams.insert(item) as u32, @@ -115,72 +91,61 @@ impl StreamOps { } } -impl Clone for StreamOps { +impl Clone for StreamOps { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl Handler for StreamOpsHandler { +impl Handler for StreamOpsHandler { fn event(&mut self, id: usize, ev: Event) { self.inner.with(|streams| { if !streams.contains(id) { return; } let item = &mut streams[id]; - if item.io.is_none() || item.contains(Flags::ERR) { - item.context.stopped(None); - return; - } + log::debug!("{}: FD event {:?} event: {:?}", item.tag(), id, ev); - let mut renew_ev = Event::new(0, false, false).with_interrupt(); - - // handle read op + let mut renew = Event::new(0, false, false).with_interrupt(); if ev.readable { - let res = item - .context - .with_read_buf(|buf, hw, lw| read(item, buf, hw, lw)); - + let res = item.read(); if res.is_pending() && item.context.is_read_ready() { - renew_ev.readable = true; + renew.readable = true; + item.flags.insert(Flags::RD); } else { - item.remove(Flags::RD); + item.flags.remove(Flags::RD); } - } else if item.contains(Flags::RD) { - renew_ev.readable = true; + } else if item.flags.contains(Flags::RD) { + renew.readable = true; } - // handle HUP - if ev.is_interrupt() && !item.contains(Flags::ERR) { - item.context.stopped(None); - close(id as u32, item, &self.inner.api); - return; - } - - // handle error - if ev.is_err() == Some(true) || ev.is_interrupt() { - item.insert(Flags::ERR); - } - - // handle write op if ev.writable { let result = item.context.with_write_buf(|buf| { log::debug!("{}: write {:?} s: {:?}", item.tag(), item.fd, buf.len()); syscall!(break libc::write(item.fd, buf[..].as_ptr() as _, buf.len())) }); if result.is_pending() { - renew_ev.writable = true; + renew.writable = true; + item.flags.insert(Flags::WR); } else { - item.remove(Flags::WR); + item.flags.remove(Flags::WR); } - } else if item.contains(Flags::WR) { - renew_ev.writable = true; + } else if item.flags.contains(Flags::WR) { + renew.writable = true; } - self.inner - .api - .modify(item.fd, id as u32, renew_ev, PollMode::Oneshot); + // handle HUP + if ev.is_interrupt() { + item.close(id as u32, &self.inner.api, None, false); + return; + } + + if !item.flags.contains(Flags::CLOSED | Flags::FAILED) { + self.inner + .api + .modify(item.fd, id as u32, renew, PollMode::Oneshot); + } // delayed drops if self.inner.delayd_drop.get() { @@ -190,14 +155,12 @@ impl Handler for StreamOpsHandler { if item.ref_count == 0 { let mut item = streams.remove(id as usize); log::debug!( - "{}: Drop ({}), {:?}, has-io: {}", + "{}: Drop ({:?}), flags: {:?}", item.tag(), - id, item.fd, - item.io.is_some() + item.flags ); - item.context.stopped(None); - close(id, &mut item, &self.inner.api); + item.close(id, &self.inner.api, None, true); } } self.inner.delayd_drop.set(false); @@ -215,18 +178,16 @@ impl Handler for StreamOpsHandler { item.fd, err ); - item.insert(Flags::ERR); - item.context.stopped(Some(err)); - close(id as u32, item, &self.inner.api); + item.close(id as u32, &self.inner.api, Some(err), false); } }) } } -impl StreamOpsInner { +impl StreamOpsInner { fn with(&self, f: F) -> R where - F: FnOnce(&mut Slab>) -> R, + F: FnOnce(&mut Slab) -> R, { let mut streams = self.streams.take().unwrap(); let result = f(&mut streams); @@ -235,110 +196,112 @@ impl StreamOpsInner { } } -fn read( - item: &StreamItem, - buf: &mut BytesVec, - hw: usize, - lw: usize, -) -> Poll> { - log::debug!( - "{}: reading fd ({:?}) flags: {:?}", - item.tag(), - item.fd, - item.context.flags() - ); - if item.contains(Flags::RDSH) { - return Poll::Ready(Ok(0)); +impl StreamItem { + fn tag(&self) -> &'static str { + self.context.tag() } - let mut total = 0; - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); + fn read(&mut self) -> Poll<()> { + let mut flags = self.flags; + let result = self.context.with_read_buf(|buf, hw, lw| { + // prev call result is 0 + if flags.contains(Flags::RDSH) { + return Poll::Ready(Ok(0)); + } + + let mut total = 0; + loop { + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + + let chunk = buf.chunk_mut(); + let chunk_len = chunk.len(); + let chunk_ptr = chunk.as_mut_ptr(); + + let result = + syscall!(break libc::read(self.fd, chunk_ptr as _, chunk.len())); + if let Poll::Ready(Ok(size)) = result { + unsafe { buf.advance_mut(size) }; + total += size; + if size == chunk_len { + continue; + } + } + + log::debug!( + "{}: read fd ({:?}), s: {:?}, cap: {:?}, result: {:?}", + self.tag(), + self.fd, + total, + buf.remaining_mut(), + result + ); + + return match result { + Poll::Ready(Err(err)) => { + flags.insert(Flags::FAILED); + if total > 0 { + self.context.stopped(Some(err)); + Poll::Ready(Ok(total)) + } else { + Poll::Ready(Err(err)) + } + } + Poll::Ready(Ok(size)) => { + if size == 0 { + flags.insert(Flags::RDSH); + } + Poll::Ready(Ok(total)) + } + Poll::Pending => { + if total > 0 { + Poll::Ready(Ok(total)) + } else { + Poll::Pending + } + } + }; + } + }); + self.flags = flags; + result + } + + fn close( + &mut self, + id: u32, + api: &DriverApi, + error: Option, + shutdown: bool, + ) -> Option>> { + if !self.flags.contains(Flags::CLOSED) { + log::debug!("{}: Closing ({}), {:?}", self.tag(), id, self.fd); + self.flags.insert(Flags::CLOSED); + self.context.stopped(error); + + let fd = self.fd; + api.detach(fd, id); + Some(ntex_rt::spawn_blocking(move || { + if shutdown { + let _ = syscall!(libc::shutdown(fd, libc::SHUT_RDWR)); + } + syscall!(libc::close(fd)) + })) + } else { + None } - - let chunk = buf.chunk_mut(); - let chunk_len = chunk.len(); - let chunk_ptr = chunk.as_mut_ptr(); - - let result = syscall!(break libc::read(item.fd, chunk_ptr as _, chunk.len())); - if let Poll::Ready(Ok(size)) = result { - unsafe { buf.advance_mut(size) }; - total += size; - if size == chunk_len { - continue; - } - } - - log::debug!( - "{}: read fd ({:?}), s: {:?}, cap: {:?}, result: {:?}", - item.tag(), - item.fd, - total, - buf.remaining_mut(), - result - ); - - return match result { - Poll::Ready(Err(err)) => { - item.insert(Flags::ERR); - if total > 0 { - item.context.stopped(Some(err)); - Poll::Ready(Ok(total)) - } else { - Poll::Ready(Err(err)) - } - } - Poll::Ready(Ok(size)) => { - if size == 0 { - item.insert(Flags::RDSH); - item.context.stopped(None); - } - Poll::Ready(Ok(total)) - } - Poll::Pending => { - if total > 0 { - Poll::Ready(Ok(total)) - } else { - Poll::Pending - } - } - }; } } -fn close( - id: u32, - item: &mut StreamItem, - api: &DriverApi, -) -> Option>> { - if let Some(io) = item.io.take() { - log::debug!("{}: Closing ({}), {:?}", item.tag(), id, item.fd); - mem::forget(io); - let fd = item.fd; - let shutdown = !item.flags.get().intersects(Flags::ERR | Flags::RDSH); - api.detach(fd, id); - Some(ntex_rt::spawn_blocking(move || { - if shutdown { - let _ = syscall!(libc::shutdown(fd, libc::SHUT_RDWR)); - } - syscall!(libc::close(fd)) - })) - } else { - None - } -} - -impl StreamCtl { +impl StreamCtl { pub(crate) fn close(self) -> impl Future> { let id = self.id as usize; - let fut = self.inner.with(|streams| { - let item = &mut streams[id]; - item.context.stopped(None); - close(self.id, item, &self.inner.api) - }); + let fut = self + .inner + .with(|streams| streams[id].close(self.id, &self.inner.api, None, true)); async move { if let Some(fut) = fut { fut.await @@ -349,52 +312,42 @@ impl StreamCtl { } } - pub(crate) fn with_io(&self, f: F) -> R - where - F: FnOnce(Option<&T>) -> R, - { - self.inner - .with(|streams| f(streams[self.id as usize].io.as_ref())) - } - pub(crate) fn modify(&self, rd: bool, wr: bool) -> bool { self.inner.with(|streams| { let item = &mut streams[self.id as usize]; - if item.io.is_none() || item.contains(Flags::ERR) { + if item.flags.contains(Flags::CLOSED) { return false; } log::debug!( - "{}: Modify interest ({}), {:?} rd: {:?}, wr: {:?}, flags: {:?}", + "{}: Modify interest ({:?}) rd: {:?}, wr: {:?}", item.tag(), - self.id, item.fd, rd, - wr, - item.flags + wr ); let mut changed = false; let mut event = Event::new(0, false, false).with_interrupt(); if rd { - if item.contains(Flags::RD) { + if item.flags.contains(Flags::RD) { event.readable = true; } else { - let res = item - .context - .with_read_buf(|buf, hw, lw| read(item, buf, hw, lw)); - + let res = item.read(); if res.is_pending() && item.context.is_read_ready() { changed = true; event.readable = true; - item.insert(Flags::RD); + item.flags.insert(Flags::RD); } } + } else if item.flags.contains(Flags::RD) { + changed = true; + item.flags.remove(Flags::RD); } if wr { - if item.contains(Flags::WR) { + if item.flags.contains(Flags::WR) { event.writable = true; } else { let result = item.context.with_write_buf(|buf| { @@ -412,12 +365,15 @@ impl StreamCtl { if result.is_pending() { changed = true; event.writable = true; - item.insert(Flags::WR); + item.flags.insert(Flags::WR); } } + } else if item.flags.contains(Flags::WR) { + changed = true; + item.flags.remove(Flags::WR); } - if changed { + if changed && !item.flags.contains(Flags::CLOSED | Flags::FAILED) { self.inner .api .modify(item.fd, self.id, event, PollMode::Oneshot); @@ -427,7 +383,7 @@ impl StreamCtl { } } -impl Clone for StreamCtl { +impl Clone for StreamCtl { fn clone(&self) -> Self { self.inner.with(|streams| { streams[self.id as usize].ref_count += 1; @@ -439,7 +395,7 @@ impl Clone for StreamCtl { } } -impl Drop for StreamCtl { +impl Drop for StreamCtl { fn drop(&mut self) { if let Some(mut streams) = self.inner.streams.take() { let id = self.id as usize; @@ -447,14 +403,12 @@ impl Drop for StreamCtl { if streams[id].ref_count == 0 { let mut item = streams.remove(id); log::debug!( - "{}: Drop io ({}), {:?}, has-io: {}", + "{}: Drop io ({:?}), flags: {:?}", item.tag(), - self.id, item.fd, - item.io.is_some() + item.flags ); - item.context.stopped(None); - close(self.id, &mut item, &self.inner.api); + item.close(self.id, &self.inner.api, None, true); } self.inner.streams.set(Some(streams)); } else { diff --git a/ntex-net/src/rt_polling/io.rs b/ntex-net/src/rt_polling/io.rs index 254343e5..2cb57323 100644 --- a/ntex-net/src/rt_polling/io.rs +++ b/ntex-net/src/rt_polling/io.rs @@ -1,4 +1,4 @@ -use std::{any, future::poll_fn, task::Poll}; +use std::{any, future::poll_fn, mem, os::fd::AsRawFd, task::Poll}; use ntex_io::{ types, Handle, IoContext, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus, @@ -12,11 +12,10 @@ impl IoStream for super::TcpStream { fn start(self, read: ReadContext, _: WriteContext) -> Option> { let io = self.0; let context = read.context(); - let ctl = StreamOps::current().register(io, context.clone()); - let ctl2 = ctl.clone(); + let ctl = StreamOps::current().register(io.as_raw_fd(), context.clone()); spawn(async move { run(ctl, context).await }); - Some(Box::new(HandleWrapper(ctl2))) + Some(Box::new(HandleWrapper(Some(io)))) } } @@ -24,19 +23,20 @@ impl IoStream for super::UnixStream { fn start(self, read: ReadContext, _: WriteContext) -> Option> { let io = self.0; let context = read.context(); - let ctl = StreamOps::current().register(io, context.clone()); + let ctl = StreamOps::current().register(io.as_raw_fd(), context.clone()); spawn(async move { run(ctl, context).await }); + mem::forget(io); None } } -struct HandleWrapper(StreamCtl); +struct HandleWrapper(Option); impl Handle for HandleWrapper { fn query(&self, id: any::TypeId) -> Option> { if id == any::TypeId::of::() { - let addr = self.0.with_io(|io| io.and_then(|io| io.peer_addr().ok())); + let addr = self.0.as_ref().unwrap().peer_addr().ok(); if let Some(addr) = addr.and_then(|addr| addr.as_socket()) { return Some(Box::new(types::PeerAddr(addr))); } @@ -45,13 +45,19 @@ impl Handle for HandleWrapper { } } +impl Drop for HandleWrapper { + fn drop(&mut self) { + mem::forget(self.0.take()); + } +} + #[derive(Copy, Clone, Debug, PartialEq, Eq)] enum Status { Shutdown, Terminate, } -async fn run(ctl: StreamCtl, context: IoContext) { +async fn run(ctl: StreamCtl, context: IoContext) { // Handle io read readiness let st = poll_fn(|cx| { let mut modify = false; @@ -98,8 +104,9 @@ async fn run(ctl: StreamCtl, context: IoContext) { .await; if st != Status::Terminate { - ctl.modify(false, true); - context.shutdown(st == Status::Shutdown).await; + if ctl.modify(false, true) { + context.shutdown(st == Status::Shutdown).await; + } } context.stopped(ctl.close().await.err()); } diff --git a/ntex-net/src/rt_polling/mod.rs b/ntex-net/src/rt_polling/mod.rs index c17a30d2..95f312b1 100644 --- a/ntex-net/src/rt_polling/mod.rs +++ b/ntex-net/src/rt_polling/mod.rs @@ -71,7 +71,7 @@ pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { #[doc(hidden)] /// Get number of active Io objects pub fn active_stream_ops() -> usize { - self::driver::StreamOps::::active_ops() + self::driver::StreamOps::active_ops() } #[cfg(all(target_os = "linux", feature = "neon"))] diff --git a/ntex-net/src/rt_uring/driver.rs b/ntex-net/src/rt_uring/driver.rs index 7115b9a7..2f76509c 100644 --- a/ntex-net/src/rt_uring/driver.rs +++ b/ntex-net/src/rt_uring/driver.rs @@ -124,6 +124,10 @@ impl StreamOps { } } + pub(crate) fn active_ops() -> usize { + Self::current().with(|st| st.streams.len()) + } + fn with(&self, f: F) -> R where F: FnOnce(&mut StreamOpsStorage) -> R, diff --git a/ntex-net/src/rt_uring/mod.rs b/ntex-net/src/rt_uring/mod.rs index 41016d09..6ae53b99 100644 --- a/ntex-net/src/rt_uring/mod.rs +++ b/ntex-net/src/rt_uring/mod.rs @@ -64,3 +64,9 @@ pub fn from_unix_stream(stream: std::os::unix::net::UnixStream) -> Result { Socket::from(stream), )?))) } + +#[doc(hidden)] +/// Get number of active Io objects +pub fn active_stream_ops() -> usize { + self::driver::StreamOps::::active_ops() +}