Refactor ntex-io (#164)

* Refactor Io and Filter types
This commit is contained in:
Nikolay Kim 2023-01-23 14:42:00 +06:00 committed by GitHub
parent 8cbd8758a5
commit 83d05d81ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 1615 additions and 1495 deletions

View file

@ -1,11 +1,13 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll};
use std::{cell::Cell, sync::Arc};
use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll};
use ntex_bytes::{BytesVec, PoolRef};
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::time::Millis;
use ntex_bytes::PoolRef;
use ntex_io::{
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
WriteStatus,
};
use ntex_util::{future::BoxFuture, time::Millis};
use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName};
mod accept;
@ -25,33 +27,33 @@ pub struct PeerCert(pub Certificate);
pub struct PeerCertChain(pub Vec<Certificate>);
/// An implementation of SSL streams
pub struct TlsFilter<F = Base> {
inner: InnerTlsFilter<F>,
pub struct TlsFilter {
inner: InnerTlsFilter,
}
enum InnerTlsFilter<F> {
Server(TlsServerFilter<F>),
Client(TlsClientFilter<F>),
enum InnerTlsFilter {
Server(TlsServerFilter),
Client(TlsClientFilter),
}
impl<F> TlsFilter<F> {
fn new_server(server: TlsServerFilter<F>) -> Self {
impl TlsFilter {
fn new_server(server: TlsServerFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Server(server),
}
}
fn new_client(client: TlsClientFilter<F>) -> Self {
fn new_client(client: TlsClientFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Client(client),
}
}
fn server(&self) -> &TlsServerFilter<F> {
fn server(&self) -> &TlsServerFilter {
match self.inner {
InnerTlsFilter::Server(ref server) => server,
_ => unreachable!(),
}
}
fn client(&self) -> &TlsClientFilter<F> {
fn client(&self) -> &TlsClientFilter {
match self.inner {
InnerTlsFilter::Client(ref server) => server,
_ => unreachable!(),
@ -59,7 +61,7 @@ impl<F> TlsFilter<F> {
}
}
impl<F: Filter> Filter for TlsFilter<F> {
impl FilterLayer for TlsFilter {
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
match self.inner {
@ -69,10 +71,10 @@ impl<F: Filter> Filter for TlsFilter<F> {
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_shutdown(),
InnerTlsFilter::Client(ref f) => f.poll_shutdown(),
InnerTlsFilter::Server(ref f) => f.shutdown(buf),
InnerTlsFilter::Client(ref f) => f.shutdown(buf),
}
}
@ -93,42 +95,18 @@ impl<F: Filter> Filter for TlsFilter<F> {
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.get_read_buf(),
InnerTlsFilter::Client(ref f) => f.get_read_buf(),
InnerTlsFilter::Server(ref f) => f.process_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_read_buf(buf),
}
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.get_write_buf(),
InnerTlsFilter::Client(ref f) => f.get_write_buf(),
}
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.release_read_buf(buf),
}
}
#[inline]
fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb),
InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb),
}
}
#[inline]
fn release_write_buf(&self, src: BytesVec) -> Result<(), io::Error> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_write_buf(src),
InnerTlsFilter::Client(ref f) => f.release_write_buf(src),
InnerTlsFilter::Server(ref f) => f.process_write_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_write_buf(buf),
}
}
}
@ -172,10 +150,10 @@ impl Clone for TlsAcceptor {
}
impl<F: Filter> FilterFactory<F> for TlsAcceptor {
type Filter = TlsFilter<F>;
type Filter = TlsFilter;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, io::Error>>>>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg.clone();
@ -227,10 +205,10 @@ impl Clone for TlsConnectorConfigured {
}
impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
type Filter = TlsFilter<F>;
type Filter = TlsFilter;
type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, io::Error>>>>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg;
@ -240,44 +218,31 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
}
}
pub(crate) struct IoInner<F> {
filter: F,
pub(crate) struct IoInner {
pool: PoolRef,
read_buf: Cell<Option<BytesVec>>,
write_buf: Cell<Option<BytesVec>>,
handshake: Cell<bool>,
}
pub(crate) struct Wrapper<'a, F>(&'a IoInner<F>);
pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a mut WriteBuf<'b>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut read_buf) = self.0.filter.get_read_buf() {
self.1.with_read_buf(|buf| {
let read_buf = buf.get_src();
let len = cmp::min(read_buf.len(), dst.len());
let result = if len > 0 {
if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
};
self.0.filter.release_read_buf(read_buf);
result
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
})
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesVec::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.filter.release_write_buf(buf)?;
self.1.get_dst().extend_from_slice(src);
Ok(src.len())
}