Update Filter trait usage

This commit is contained in:
Nikolay Kim 2022-01-12 22:08:50 +06:00
parent 1df005f53f
commit b49a5ed195
6 changed files with 203 additions and 275 deletions

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.1.2] - 2022-01-12
* Update Filter trait usage
## [0.1.1] - 2022-01-10 ## [0.1.1] - 2022-01-10
* Remove usage of ntex::io::Boxed types * Remove usage of ntex::io::Boxed types

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-tls" name = "ntex-tls"
version = "0.1.1" version = "0.1.2"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL" description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -26,8 +26,8 @@ rustls = ["tls_rust"]
[dependencies] [dependencies]
ntex-bytes = "0.1.9" ntex-bytes = "0.1.9"
ntex-io = "0.1.2" ntex-io = "0.1.3"
ntex-util = "0.1.8" ntex-util = "0.1.9"
ntex-service = "0.3.1" ntex-service = "0.3.1"
pin-project-lite = "0.2" pin-project-lite = "0.2"

View file

@ -27,25 +27,26 @@ pub struct PeerCertChain(pub Vec<X509>);
/// An implementation of SSL streams /// An implementation of SSL streams
pub struct SslFilter<F = Base> { pub struct SslFilter<F = Base> {
inner: RefCell<SslStream<IoInner<F>>>, inner: RefCell<SslStream<IoInner<F>>>,
pool: PoolRef,
handshake: Cell<bool>, handshake: Cell<bool>,
read_buf: Cell<Option<BytesMut>>,
} }
struct IoInner<F> { struct IoInner<F> {
inner: F, inner: F,
pool: PoolRef, pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>, write_buf: Option<BytesMut>,
} }
impl<F: Filter> io::Read for IoInner<F> { impl<F: Filter> io::Read for IoInner<F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(ref mut buf) = self.read_buf { if let Some(mut buf) = self.inner.get_read_buf() {
if buf.is_empty() { if buf.is_empty() {
buf.clear();
Err(io::Error::from(io::ErrorKind::WouldBlock)) Err(io::Error::from(io::ErrorKind::WouldBlock))
} else { } else {
let len = cmp::min(buf.len(), dst.len()); let len = cmp::min(buf.len(), dst.len());
dst[..len].copy_from_slice(&buf.split_to(len)); dst[..len].copy_from_slice(&buf.split_to(len));
self.inner.release_read_buf(buf);
Ok(len) Ok(len)
} }
} else { } else {
@ -139,70 +140,54 @@ impl<F: Filter> Filter for SslFilter<F> {
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesMut> { fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() { self.read_buf.take()
if !buf.is_empty() {
return Some(buf);
}
}
None
} }
#[inline] #[inline]
fn get_write_buf(&self) -> Option<BytesMut> { fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() { self.inner.borrow_mut().get_mut().write_buf.take()
if !buf.is_empty() {
return Some(buf);
}
}
None
} }
fn release_read_buf( #[inline]
&self, fn release_read_buf(&self, buf: BytesMut) {
io: &IoRef, self.read_buf.set(Some(buf));
src: BytesMut, }
dst: &mut Option<BytesMut>,
nbytes: usize, fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
) -> io::Result<usize> { // ask inner filter to process read buf
// store to read_buf match self
let pool = { .inner
let mut inner = self.inner.borrow_mut(); .borrow_mut()
let mut dst = None;
let result = inner
.get_ref() .get_ref()
.inner .inner
.release_read_buf(io, src, &mut dst, nbytes); .process_read_buf(io, nbytes)
if let Err(err) = result { {
io.want_shutdown(Some(err)); Err(err) => io.want_shutdown(Some(err)),
Ok((n, 0)) => return Ok((n, 0)),
Ok((_, _)) => (),
} }
if dst.is_some() {
inner.get_mut().read_buf = dst;
inner.get_ref().pool
} else {
return Ok(0);
}
};
let (hw, lw) = pool.read_params().unpack();
// get inner filter buffer // get processed buffer
if dst.is_none() { let mut dst = if let Some(dst) = self.get_read_buf() {
*dst = Some(pool.get_read_buf()); dst
} } else {
let buf = dst.as_mut().unwrap(); self.pool.get_read_buf()
};
let (hw, lw) = self.pool.read_params().unpack();
let mut new_bytes = 0; let mut new_bytes = 0;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = dst.remaining_mut();
if remaining < lw { if remaining < lw {
buf.reserve(hw - remaining); dst.reserve(hw - remaining);
} }
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *buf.chunk_mut()) }; let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let ssl_result = self.inner.borrow_mut().ssl_read(chunk); let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
return match ssl_result { let result = match ssl_result {
Ok(v) => { Ok(v) => {
unsafe { buf.advance_mut(v) }; unsafe { dst.advance_mut(v) };
new_bytes += v; new_bytes += v;
continue; continue;
} }
@ -216,14 +201,16 @@ impl<F: Filter> Filter for SslFilter<F> {
self.handshake.set(false); self.handshake.set(false);
} }
} }
Ok(new_bytes) Ok((dst.len(), new_bytes))
} }
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => { Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
io.want_shutdown(None); io.want_shutdown(None);
Ok(new_bytes) Ok((dst.len(), new_bytes))
} }
Err(e) => Err(map_to_ioerr(e)), Err(e) => Err(map_to_ioerr(e)),
}; };
self.release_read_buf(dst);
return result;
} }
} }
@ -299,18 +286,18 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
let ssl = ctx_result.map_err(map_to_ioerr)?; let ssl = ctx_result.map_err(map_to_ioerr)?;
let pool = st.memory_pool(); let pool = st.memory_pool();
let st = st.map_filter(|inner: F| { let st = st.map_filter(|inner: F| {
let read_buf = inner.get_read_buf();
let inner = IoInner { let inner = IoInner {
pool, pool,
inner, inner,
read_buf,
write_buf: None, write_buf: None,
}; };
let ssl_stream = ssl::SslStream::new(ssl, inner)?; let ssl_stream = ssl::SslStream::new(ssl, inner)?;
Ok::<_, Box<dyn Error>>(SslFilter { Ok::<_, Box<dyn Error>>(SslFilter {
inner: RefCell::new(ssl_stream), pool,
read_buf: Cell::new(None),
handshake: Cell::new(true), handshake: Cell::new(true),
inner: RefCell::new(ssl_stream),
}) })
})?; })?;
@ -352,18 +339,18 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
let ssl = self.ssl; let ssl = self.ssl;
let pool = st.memory_pool(); let pool = st.memory_pool();
let st = st.map_filter(|inner: F| { let st = st.map_filter(|inner: F| {
let read_buf = inner.get_read_buf();
let inner = IoInner { let inner = IoInner {
pool, pool,
inner, inner,
read_buf,
write_buf: None, write_buf: None,
}; };
let ssl_stream = ssl::SslStream::new(ssl, inner)?; let ssl_stream = ssl::SslStream::new(ssl, inner)?;
Ok::<_, Box<dyn Error>>(SslFilter { Ok::<_, Box<dyn Error>>(SslFilter {
inner: RefCell::new(ssl_stream), pool,
read_buf: Cell::new(None),
handshake: Cell::new(true), handshake: Cell::new(true),
inner: RefCell::new(ssl_stream),
}) })
})?; })?;

View file

@ -1,13 +1,13 @@
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite}; use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::RefCell, cmp, sync::Arc, task::Context, task::Poll}; use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_bytes::{BufMut, BytesMut};
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus}; use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::{future::poll_fn, ready}; use ntex_util::{future::poll_fn, ready};
use tls_rust::{ClientConfig, ClientConnection, ServerName}; use tls_rust::{ClientConfig, ClientConnection, ServerName};
use super::TlsFilter; use crate::rustls::{IoInner, TlsFilter, Wrapper};
use crate::types; use crate::types;
/// An implementation of SSL streams /// An implementation of SSL streams
@ -16,13 +16,6 @@ pub struct TlsClientFilter<F> {
session: RefCell<ClientConnection>, session: RefCell<ClientConnection>,
} }
struct IoInner<F> {
inner: F,
pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
impl<F: Filter> Filter for TlsClientFilter<F> { impl<F: Filter> Filter for TlsClientFilter<F> {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
}; };
Some(Box::new(proto)) Some(Box::new(proto))
} else { } else {
self.inner.borrow().inner.query(id) self.inner.borrow().filter.query(id)
} }
} }
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.borrow().inner.poll_shutdown() self.inner.borrow().filter.poll_shutdown()
} }
#[inline] #[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.borrow().inner.poll_read_ready(cx) self.inner.borrow().filter.poll_read_ready(cx)
} }
#[inline] #[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.borrow().inner.poll_write_ready(cx) self.inner.borrow().filter.poll_write_ready(cx)
} }
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesMut> { fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() { self.inner.borrow_mut().read_buf.take()
if !buf.is_empty() {
return Some(buf);
}
}
None
} }
#[inline] #[inline]
fn get_write_buf(&self) -> Option<BytesMut> { fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() { self.inner.borrow_mut().write_buf.take()
if !buf.is_empty() {
return Some(buf);
}
}
None
} }
fn release_read_buf( #[inline]
&self, fn release_read_buf(&self, buf: BytesMut) {
io: &IoRef, self.inner.borrow_mut().read_buf = Some(buf);
src: BytesMut, }
dst: &mut Option<BytesMut>,
nbytes: usize, fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
) -> io::Result<usize> {
let mut inner = self.inner.borrow_mut(); let mut inner = self.inner.borrow_mut();
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
if session.is_handshaking() { // ask inner filter to process read buf
inner.read_buf = Some(src); match inner.filter.process_read_buf(io, nbytes) {
Ok(1) Err(err) => io.want_shutdown(Some(err)),
} else { Ok((_, 0)) => return Ok((0, 0)),
let mut src = { Ok(_) => (),
let mut dst = None;
if let Err(err) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
io.want_shutdown(Some(err));
} }
if let Some(dst) = dst { if session.is_handshaking() {
Ok((0, 1))
} else {
// get processed buffer
let mut dst = if let Some(dst) = inner.read_buf.take() {
dst dst
} else { } else {
return Ok(0); inner.pool.get_read_buf()
}
}; };
let (hw, lw) = inner.pool.read_params().unpack(); let (hw, lw) = inner.pool.read_params().unpack();
// get inner filter buffer let mut src = if let Some(src) = inner.filter.get_read_buf() {
if dst.is_none() { src
*dst = Some(inner.pool.get_read_buf()); } else {
} return Ok((0, 0));
let buf = dst.as_mut().unwrap(); };
let mut new_bytes = 0; let mut new_bytes = 0;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = dst.remaining_mut();
if remaining < lw { if remaining < lw {
buf.reserve(hw - remaining); dst.reserve(hw - remaining);
} }
let mut cursor = io::Cursor::new(&src); let mut cursor = io::Cursor::new(&src);
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
let new_b = state.plaintext_bytes_to_read(); let new_b = state.plaintext_bytes_to_read();
if new_b > 0 { if new_b > 0 {
buf.reserve(new_b); dst.reserve(new_b);
let chunk: &mut [u8] = let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) }; unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?; let v = session.reader().read(chunk)?;
unsafe { buf.advance_mut(v) }; unsafe { dst.advance_mut(v) };
new_bytes += v; new_bytes += v;
} else { } else {
break; break;
} }
} }
if !src.is_empty() { let dst_len = dst.len();
inner.read_buf = Some(src); inner.read_buf = Some(dst);
} inner.filter.release_read_buf(src);
Ok(new_bytes) Ok((dst_len, new_bytes))
} }
} }
@ -175,42 +157,6 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
} }
} }
struct Wrapper<'a, F>(&'a mut IoInner<F>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(read_buf) = self.0.read_buf.as_mut() {
let len = cmp::min(read_buf.len(), dst.len());
if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesMut::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.inner.release_write_buf(buf)?;
Ok(src.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<F: Filter> TlsClientFilter<F> { impl<F: Filter> TlsClientFilter<F> {
pub(crate) async fn create( pub(crate) async fn create(
io: Io<F>, io: Io<F>,
@ -222,12 +168,11 @@ impl<F: Filter> TlsClientFilter<F> {
Ok(session) => session, Ok(session) => session,
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)), Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
}; };
let io = io.map_filter(|inner: F| { let io = io.map_filter(|filter: F| {
let read_buf = inner.get_read_buf();
let inner = IoInner { let inner = IoInner {
pool, pool,
inner, filter,
read_buf, read_buf: None,
write_buf: None, write_buf: None,
}; };

View file

@ -1,9 +1,9 @@
#![allow(clippy::type_complexity)] #![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::sync::Arc; use std::sync::Arc;
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll}; use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll};
use ntex_bytes::BytesMut; use ntex_bytes::{BytesMut, PoolRef};
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus}; use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::time::Millis; use ntex_util::time::Millis;
use tls_rust::{ClientConfig, ServerConfig, ServerName}; use tls_rust::{ClientConfig, ServerConfig, ServerName};
@ -101,16 +101,18 @@ impl<F: Filter> Filter for TlsFilter<F> {
} }
#[inline] #[inline]
fn release_read_buf( fn release_read_buf(&self, buf: BytesMut) {
&self,
io: &IoRef,
src: BytesMut,
dst: &mut Option<BytesMut>,
nb: usize,
) -> io::Result<usize> {
match self.inner { match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(io, src, dst, nb), InnerTlsFilter::Server(ref f) => f.release_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.release_read_buf(io, src, dst, nb), InnerTlsFilter::Client(ref f) => f.release_read_buf(buf),
}
}
#[inline]
fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb),
InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb),
} }
} }
@ -229,3 +231,48 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await }) Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
} }
} }
pub(crate) struct IoInner<F> {
filter: F,
pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
pub(crate) struct Wrapper<'a, F>(&'a mut IoInner<F>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut read_buf) = self.0.filter.get_read_buf() {
let len = cmp::min(read_buf.len(), dst.len());
let result = if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
};
self.0.filter.release_read_buf(read_buf);
result
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesMut::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.filter.release_write_buf(buf)?;
Ok(src.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

View file

@ -1,14 +1,14 @@
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite}; use std::io::{self, Read as IoRead, Write as IoWrite};
use std::sync::Arc; use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_bytes::{BufMut, BytesMut};
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus}; use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::{future::poll_fn, ready, time, time::Millis}; use ntex_util::{future::poll_fn, ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection}; use tls_rust::{ServerConfig, ServerConnection};
use crate::{rustls::TlsFilter, types}; use crate::rustls::{IoInner, TlsFilter, Wrapper};
use crate::types;
/// An implementation of SSL streams /// An implementation of SSL streams
pub struct TlsServerFilter<F> { pub struct TlsServerFilter<F> {
@ -16,13 +16,6 @@ pub struct TlsServerFilter<F> {
session: RefCell<ServerConnection>, session: RefCell<ServerConnection>,
} }
struct IoInner<F> {
inner: F,
pool: PoolRef,
read_buf: Option<BytesMut>,
write_buf: Option<BytesMut>,
}
impl<F: Filter> Filter for TlsServerFilter<F> { impl<F: Filter> Filter for TlsServerFilter<F> {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
}; };
Some(Box::new(proto)) Some(Box::new(proto))
} else { } else {
self.inner.borrow().inner.query(id) self.inner.borrow().filter.query(id)
} }
} }
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.borrow().inner.poll_shutdown() self.inner.borrow().filter.poll_shutdown()
} }
#[inline] #[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.borrow().inner.poll_read_ready(cx) self.inner.borrow().filter.poll_read_ready(cx)
} }
#[inline] #[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.borrow().inner.poll_write_ready(cx) self.inner.borrow().filter.poll_write_ready(cx)
} }
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesMut> { fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() { self.inner.borrow_mut().read_buf.take()
if !buf.is_empty() {
return Some(buf);
}
}
None
} }
#[inline] #[inline]
fn get_write_buf(&self) -> Option<BytesMut> { fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() { self.inner.borrow_mut().write_buf.take()
if !buf.is_empty() {
return Some(buf);
}
}
None
} }
fn release_read_buf( #[inline]
&self, fn release_read_buf(&self, buf: BytesMut) {
io: &IoRef, self.inner.borrow_mut().read_buf = Some(buf);
src: BytesMut, }
dst: &mut Option<BytesMut>,
nbytes: usize, fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
) -> io::Result<usize> {
let mut inner = self.inner.borrow_mut(); let mut inner = self.inner.borrow_mut();
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
if session.is_handshaking() { // ask inner filter to process read buf
inner.read_buf = Some(src); match inner.filter.process_read_buf(io, nbytes) {
Ok(1) Err(err) => io.want_shutdown(Some(err)),
} else { Ok((_, 0)) => return Ok((0, 0)),
let mut src = { Ok(_) => (),
let mut dst = None;
if let Err(e) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
io.want_shutdown(Some(e));
} }
if let Some(dst) = dst { if session.is_handshaking() {
Ok((0, 1))
} else {
// get processed buffer
let mut dst = if let Some(dst) = inner.read_buf.take() {
dst dst
} else { } else {
return Ok(0); inner.pool.get_read_buf()
}
}; };
let (hw, lw) = inner.pool.read_params().unpack(); let (hw, lw) = inner.pool.read_params().unpack();
// get inner filter buffer let mut src = if let Some(src) = inner.filter.get_read_buf() {
if dst.is_none() { src
*dst = Some(inner.pool.get_read_buf()); } else {
} return Ok((0, 0));
let buf = dst.as_mut().unwrap(); };
let mut new_bytes = 0; let mut new_bytes = 0;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = dst.remaining_mut();
if remaining < lw { if remaining < lw {
buf.reserve(hw - remaining); dst.reserve(hw - remaining);
} }
let mut cursor = io::Cursor::new(&src); let mut cursor = io::Cursor::new(&src);
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
let new_b = state.plaintext_bytes_to_read(); let new_b = state.plaintext_bytes_to_read();
if new_b > 0 { if new_b > 0 {
buf.reserve(new_b); dst.reserve(new_b);
let chunk: &mut [u8] = let chunk: &mut [u8] =
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) }; unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let v = session.reader().read(chunk)?; let v = session.reader().read(chunk)?;
unsafe { buf.advance_mut(v) }; unsafe { dst.advance_mut(v) };
new_bytes += v; new_bytes += v;
} else { } else {
break; break;
} }
} }
if !src.is_empty() { let dst_len = dst.len();
inner.read_buf = Some(src); inner.read_buf = Some(dst);
} inner.filter.release_read_buf(src);
Ok(new_bytes) Ok((dst_len, new_bytes))
} }
} }
@ -174,42 +156,6 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
} }
} }
struct Wrapper<'a, F>(&'a mut IoInner<F>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(read_buf) = self.0.read_buf.as_mut() {
let len = cmp::min(read_buf.len(), dst.len());
if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len)
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
}
}
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
buf.reserve(src.len());
buf
} else {
BytesMut::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.inner.release_write_buf(buf)?;
Ok(src.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl<F: Filter> TlsServerFilter<F> { impl<F: Filter> TlsServerFilter<F> {
pub(crate) async fn create( pub(crate) async fn create(
io: Io<F>, io: Io<F>,
@ -222,12 +168,11 @@ impl<F: Filter> TlsServerFilter<F> {
Ok(session) => session, Ok(session) => session,
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)), Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
}; };
let io = io.map_filter(|inner: F| { let io = io.map_filter(|filter: F| {
let read_buf = inner.get_read_buf();
let inner = IoInner { let inner = IoInner {
pool, pool,
inner, filter,
read_buf, read_buf: None,
write_buf: None, write_buf: None,
}; };