mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-05 22:07:38 +03:00
Update Filter trait usage
This commit is contained in:
parent
1df005f53f
commit
b49a5ed195
6 changed files with 203 additions and 275 deletions
|
@ -1,5 +1,9 @@
|
|||
# Changes
|
||||
|
||||
## [0.1.2] - 2022-01-12
|
||||
|
||||
* Update Filter trait usage
|
||||
|
||||
## [0.1.1] - 2022-01-10
|
||||
|
||||
* Remove usage of ntex::io::Boxed types
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex-tls"
|
||||
version = "0.1.1"
|
||||
version = "0.1.2"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
||||
keywords = ["network", "framework", "async", "futures"]
|
||||
|
@ -26,8 +26,8 @@ rustls = ["tls_rust"]
|
|||
|
||||
[dependencies]
|
||||
ntex-bytes = "0.1.9"
|
||||
ntex-io = "0.1.2"
|
||||
ntex-util = "0.1.8"
|
||||
ntex-io = "0.1.3"
|
||||
ntex-util = "0.1.9"
|
||||
ntex-service = "0.3.1"
|
||||
pin-project-lite = "0.2"
|
||||
|
||||
|
|
|
@ -27,25 +27,26 @@ pub struct PeerCertChain(pub Vec<X509>);
|
|||
/// An implementation of SSL streams
|
||||
pub struct SslFilter<F = Base> {
|
||||
inner: RefCell<SslStream<IoInner<F>>>,
|
||||
pool: PoolRef,
|
||||
handshake: Cell<bool>,
|
||||
read_buf: Cell<Option<BytesMut>>,
|
||||
}
|
||||
|
||||
struct IoInner<F> {
|
||||
inner: F,
|
||||
pool: PoolRef,
|
||||
read_buf: Option<BytesMut>,
|
||||
write_buf: Option<BytesMut>,
|
||||
}
|
||||
|
||||
impl<F: Filter> io::Read for IoInner<F> {
|
||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
||||
if let Some(ref mut buf) = self.read_buf {
|
||||
if let Some(mut buf) = self.inner.get_read_buf() {
|
||||
if buf.is_empty() {
|
||||
buf.clear();
|
||||
Err(io::Error::from(io::ErrorKind::WouldBlock))
|
||||
} else {
|
||||
let len = cmp::min(buf.len(), dst.len());
|
||||
dst[..len].copy_from_slice(&buf.split_to(len));
|
||||
self.inner.release_read_buf(buf);
|
||||
Ok(len)
|
||||
}
|
||||
} else {
|
||||
|
@ -139,70 +140,54 @@ impl<F: Filter> Filter for SslFilter<F> {
|
|||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
None
|
||||
self.read_buf.take()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_write_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
None
|
||||
self.inner.borrow_mut().get_mut().write_buf.take()
|
||||
}
|
||||
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
io: &IoRef,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
// store to read_buf
|
||||
let pool = {
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
let mut dst = None;
|
||||
let result = inner
|
||||
.get_ref()
|
||||
.inner
|
||||
.release_read_buf(io, src, &mut dst, nbytes);
|
||||
if let Err(err) = result {
|
||||
io.want_shutdown(Some(err));
|
||||
}
|
||||
if dst.is_some() {
|
||||
inner.get_mut().read_buf = dst;
|
||||
inner.get_ref().pool
|
||||
} else {
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
#[inline]
|
||||
fn release_read_buf(&self, buf: BytesMut) {
|
||||
self.read_buf.set(Some(buf));
|
||||
}
|
||||
|
||||
// get inner filter buffer
|
||||
if dst.is_none() {
|
||||
*dst = Some(pool.get_read_buf());
|
||||
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
|
||||
// ask inner filter to process read buf
|
||||
match self
|
||||
.inner
|
||||
.borrow_mut()
|
||||
.get_ref()
|
||||
.inner
|
||||
.process_read_buf(io, nbytes)
|
||||
{
|
||||
Err(err) => io.want_shutdown(Some(err)),
|
||||
Ok((n, 0)) => return Ok((n, 0)),
|
||||
Ok((_, _)) => (),
|
||||
}
|
||||
let buf = dst.as_mut().unwrap();
|
||||
|
||||
// get processed buffer
|
||||
let mut dst = if let Some(dst) = self.get_read_buf() {
|
||||
dst
|
||||
} else {
|
||||
self.pool.get_read_buf()
|
||||
};
|
||||
let (hw, lw) = self.pool.read_params().unpack();
|
||||
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
let remaining = dst.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
dst.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
|
||||
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
|
||||
let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
|
||||
return match ssl_result {
|
||||
let result = match ssl_result {
|
||||
Ok(v) => {
|
||||
unsafe { buf.advance_mut(v) };
|
||||
unsafe { dst.advance_mut(v) };
|
||||
new_bytes += v;
|
||||
continue;
|
||||
}
|
||||
|
@ -216,14 +201,16 @@ impl<F: Filter> Filter for SslFilter<F> {
|
|||
self.handshake.set(false);
|
||||
}
|
||||
}
|
||||
Ok(new_bytes)
|
||||
Ok((dst.len(), new_bytes))
|
||||
}
|
||||
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
|
||||
io.want_shutdown(None);
|
||||
Ok(new_bytes)
|
||||
Ok((dst.len(), new_bytes))
|
||||
}
|
||||
Err(e) => Err(map_to_ioerr(e)),
|
||||
};
|
||||
self.release_read_buf(dst);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -299,18 +286,18 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
|
|||
let ssl = ctx_result.map_err(map_to_ioerr)?;
|
||||
let pool = st.memory_pool();
|
||||
let st = st.map_filter(|inner: F| {
|
||||
let read_buf = inner.get_read_buf();
|
||||
let inner = IoInner {
|
||||
pool,
|
||||
inner,
|
||||
read_buf,
|
||||
write_buf: None,
|
||||
};
|
||||
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
||||
|
||||
Ok::<_, Box<dyn Error>>(SslFilter {
|
||||
inner: RefCell::new(ssl_stream),
|
||||
pool,
|
||||
read_buf: Cell::new(None),
|
||||
handshake: Cell::new(true),
|
||||
inner: RefCell::new(ssl_stream),
|
||||
})
|
||||
})?;
|
||||
|
||||
|
@ -352,18 +339,18 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
|
|||
let ssl = self.ssl;
|
||||
let pool = st.memory_pool();
|
||||
let st = st.map_filter(|inner: F| {
|
||||
let read_buf = inner.get_read_buf();
|
||||
let inner = IoInner {
|
||||
pool,
|
||||
inner,
|
||||
read_buf,
|
||||
write_buf: None,
|
||||
};
|
||||
let ssl_stream = ssl::SslStream::new(ssl, inner)?;
|
||||
|
||||
Ok::<_, Box<dyn Error>>(SslFilter {
|
||||
inner: RefCell::new(ssl_stream),
|
||||
pool,
|
||||
read_buf: Cell::new(None),
|
||||
handshake: Cell::new(true),
|
||||
inner: RefCell::new(ssl_stream),
|
||||
})
|
||||
})?;
|
||||
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::io::{self, Read as IoRead, Write as IoWrite};
|
||||
use std::{any, cell::RefCell, cmp, sync::Arc, task::Context, task::Poll};
|
||||
use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
|
||||
|
||||
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
||||
use ntex_bytes::{BufMut, BytesMut};
|
||||
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
|
||||
use ntex_util::{future::poll_fn, ready};
|
||||
use tls_rust::{ClientConfig, ClientConnection, ServerName};
|
||||
|
||||
use super::TlsFilter;
|
||||
use crate::rustls::{IoInner, TlsFilter, Wrapper};
|
||||
use crate::types;
|
||||
|
||||
/// An implementation of SSL streams
|
||||
|
@ -16,13 +16,6 @@ pub struct TlsClientFilter<F> {
|
|||
session: RefCell<ClientConnection>,
|
||||
}
|
||||
|
||||
struct IoInner<F> {
|
||||
inner: F,
|
||||
pool: PoolRef,
|
||||
read_buf: Option<BytesMut>,
|
||||
write_buf: Option<BytesMut>,
|
||||
}
|
||||
|
||||
impl<F: Filter> Filter for TlsClientFilter<F> {
|
||||
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
|
||||
const H2: &[u8] = b"h2";
|
||||
|
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
|||
};
|
||||
Some(Box::new(proto))
|
||||
} else {
|
||||
self.inner.borrow().inner.query(id)
|
||||
self.inner.borrow().filter.query(id)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
|
||||
self.inner.borrow().inner.poll_shutdown()
|
||||
self.inner.borrow().filter.poll_shutdown()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
|
||||
self.inner.borrow().inner.poll_read_ready(cx)
|
||||
self.inner.borrow().filter.poll_read_ready(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
||||
self.inner.borrow().inner.poll_write_ready(cx)
|
||||
self.inner.borrow().filter.poll_write_ready(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
None
|
||||
self.inner.borrow_mut().read_buf.take()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_write_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
None
|
||||
self.inner.borrow_mut().write_buf.take()
|
||||
}
|
||||
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
io: &IoRef,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
#[inline]
|
||||
fn release_read_buf(&self, buf: BytesMut) {
|
||||
self.inner.borrow_mut().read_buf = Some(buf);
|
||||
}
|
||||
|
||||
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
let mut session = self.session.borrow_mut();
|
||||
|
||||
if session.is_handshaking() {
|
||||
inner.read_buf = Some(src);
|
||||
Ok(1)
|
||||
} else {
|
||||
let mut src = {
|
||||
let mut dst = None;
|
||||
if let Err(err) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
|
||||
io.want_shutdown(Some(err));
|
||||
}
|
||||
// ask inner filter to process read buf
|
||||
match inner.filter.process_read_buf(io, nbytes) {
|
||||
Err(err) => io.want_shutdown(Some(err)),
|
||||
Ok((_, 0)) => return Ok((0, 0)),
|
||||
Ok(_) => (),
|
||||
}
|
||||
|
||||
if let Some(dst) = dst {
|
||||
dst
|
||||
} else {
|
||||
return Ok(0);
|
||||
}
|
||||
if session.is_handshaking() {
|
||||
Ok((0, 1))
|
||||
} else {
|
||||
// get processed buffer
|
||||
let mut dst = if let Some(dst) = inner.read_buf.take() {
|
||||
dst
|
||||
} else {
|
||||
inner.pool.get_read_buf()
|
||||
};
|
||||
let (hw, lw) = inner.pool.read_params().unpack();
|
||||
|
||||
// get inner filter buffer
|
||||
if dst.is_none() {
|
||||
*dst = Some(inner.pool.get_read_buf());
|
||||
}
|
||||
let buf = dst.as_mut().unwrap();
|
||||
let mut src = if let Some(src) = inner.filter.get_read_buf() {
|
||||
src
|
||||
} else {
|
||||
return Ok((0, 0));
|
||||
};
|
||||
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
let remaining = dst.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
dst.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
let mut cursor = io::Cursor::new(&src);
|
||||
|
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
|||
|
||||
let new_b = state.plaintext_bytes_to_read();
|
||||
if new_b > 0 {
|
||||
buf.reserve(new_b);
|
||||
dst.reserve(new_b);
|
||||
let chunk: &mut [u8] =
|
||||
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
|
||||
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
|
||||
let v = session.reader().read(chunk)?;
|
||||
unsafe { buf.advance_mut(v) };
|
||||
unsafe { dst.advance_mut(v) };
|
||||
new_bytes += v;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !src.is_empty() {
|
||||
inner.read_buf = Some(src);
|
||||
}
|
||||
Ok(new_bytes)
|
||||
let dst_len = dst.len();
|
||||
inner.read_buf = Some(dst);
|
||||
inner.filter.release_read_buf(src);
|
||||
Ok((dst_len, new_bytes))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -175,42 +157,6 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
|||
}
|
||||
}
|
||||
|
||||
struct Wrapper<'a, F>(&'a mut IoInner<F>);
|
||||
|
||||
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
|
||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
||||
if let Some(read_buf) = self.0.read_buf.as_mut() {
|
||||
let len = cmp::min(read_buf.len(), dst.len());
|
||||
if len > 0 {
|
||||
dst[..len].copy_from_slice(&read_buf.split_to(len));
|
||||
Ok(len)
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||
}
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
||||
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
|
||||
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
|
||||
buf.reserve(src.len());
|
||||
buf
|
||||
} else {
|
||||
BytesMut::with_capacity_in(src.len(), self.0.pool)
|
||||
};
|
||||
buf.extend_from_slice(src);
|
||||
self.0.inner.release_write_buf(buf)?;
|
||||
Ok(src.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> TlsClientFilter<F> {
|
||||
pub(crate) async fn create(
|
||||
io: Io<F>,
|
||||
|
@ -222,12 +168,11 @@ impl<F: Filter> TlsClientFilter<F> {
|
|||
Ok(session) => session,
|
||||
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
||||
};
|
||||
let io = io.map_filter(|inner: F| {
|
||||
let read_buf = inner.get_read_buf();
|
||||
let io = io.map_filter(|filter: F| {
|
||||
let inner = IoInner {
|
||||
pool,
|
||||
inner,
|
||||
read_buf,
|
||||
filter,
|
||||
read_buf: None,
|
||||
write_buf: None,
|
||||
};
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
#![allow(clippy::type_complexity)]
|
||||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::sync::Arc;
|
||||
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll};
|
||||
use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll};
|
||||
|
||||
use ntex_bytes::BytesMut;
|
||||
use ntex_bytes::{BytesMut, PoolRef};
|
||||
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
|
||||
use ntex_util::time::Millis;
|
||||
use tls_rust::{ClientConfig, ServerConfig, ServerName};
|
||||
|
@ -101,16 +101,18 @@ impl<F: Filter> Filter for TlsFilter<F> {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
io: &IoRef,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nb: usize,
|
||||
) -> io::Result<usize> {
|
||||
fn release_read_buf(&self, buf: BytesMut) {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.release_read_buf(io, src, dst, nb),
|
||||
InnerTlsFilter::Client(ref f) => f.release_read_buf(io, src, dst, nb),
|
||||
InnerTlsFilter::Server(ref f) => f.release_read_buf(buf),
|
||||
InnerTlsFilter::Client(ref f) => f.release_read_buf(buf),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb),
|
||||
InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -229,3 +231,48 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
|
|||
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct IoInner<F> {
|
||||
filter: F,
|
||||
pool: PoolRef,
|
||||
read_buf: Option<BytesMut>,
|
||||
write_buf: Option<BytesMut>,
|
||||
}
|
||||
|
||||
pub(crate) struct Wrapper<'a, F>(&'a mut IoInner<F>);
|
||||
|
||||
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
|
||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
||||
if let Some(mut read_buf) = self.0.filter.get_read_buf() {
|
||||
let len = cmp::min(read_buf.len(), dst.len());
|
||||
let result = if len > 0 {
|
||||
dst[..len].copy_from_slice(&read_buf.split_to(len));
|
||||
Ok(len)
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||
};
|
||||
self.0.filter.release_read_buf(read_buf);
|
||||
result
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
||||
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
|
||||
let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() {
|
||||
buf.reserve(src.len());
|
||||
buf
|
||||
} else {
|
||||
BytesMut::with_capacity_in(src.len(), self.0.pool)
|
||||
};
|
||||
buf.extend_from_slice(src);
|
||||
self.0.filter.release_write_buf(buf)?;
|
||||
Ok(src.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::io::{self, Read as IoRead, Write as IoWrite};
|
||||
use std::sync::Arc;
|
||||
use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
|
||||
use std::{any, cell::RefCell, sync::Arc, task::Context, task::Poll};
|
||||
|
||||
use ntex_bytes::{BufMut, BytesMut, PoolRef};
|
||||
use ntex_bytes::{BufMut, BytesMut};
|
||||
use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
|
||||
use ntex_util::{future::poll_fn, ready, time, time::Millis};
|
||||
use tls_rust::{ServerConfig, ServerConnection};
|
||||
|
||||
use crate::{rustls::TlsFilter, types};
|
||||
use crate::rustls::{IoInner, TlsFilter, Wrapper};
|
||||
use crate::types;
|
||||
|
||||
/// An implementation of SSL streams
|
||||
pub struct TlsServerFilter<F> {
|
||||
|
@ -16,13 +16,6 @@ pub struct TlsServerFilter<F> {
|
|||
session: RefCell<ServerConnection>,
|
||||
}
|
||||
|
||||
struct IoInner<F> {
|
||||
inner: F,
|
||||
pool: PoolRef,
|
||||
read_buf: Option<BytesMut>,
|
||||
write_buf: Option<BytesMut>,
|
||||
}
|
||||
|
||||
impl<F: Filter> Filter for TlsServerFilter<F> {
|
||||
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
|
||||
const H2: &[u8] = b"h2";
|
||||
|
@ -42,85 +35,74 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
|||
};
|
||||
Some(Box::new(proto))
|
||||
} else {
|
||||
self.inner.borrow().inner.query(id)
|
||||
self.inner.borrow().filter.query(id)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
|
||||
self.inner.borrow().inner.poll_shutdown()
|
||||
self.inner.borrow().filter.poll_shutdown()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
|
||||
self.inner.borrow().inner.poll_read_ready(cx)
|
||||
self.inner.borrow().filter.poll_read_ready(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
||||
self.inner.borrow().inner.poll_write_ready(cx)
|
||||
self.inner.borrow().filter.poll_write_ready(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
None
|
||||
self.inner.borrow_mut().read_buf.take()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_write_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
None
|
||||
self.inner.borrow_mut().write_buf.take()
|
||||
}
|
||||
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
io: &IoRef,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
#[inline]
|
||||
fn release_read_buf(&self, buf: BytesMut) {
|
||||
self.inner.borrow_mut().read_buf = Some(buf);
|
||||
}
|
||||
|
||||
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
let mut session = self.session.borrow_mut();
|
||||
|
||||
if session.is_handshaking() {
|
||||
inner.read_buf = Some(src);
|
||||
Ok(1)
|
||||
} else {
|
||||
let mut src = {
|
||||
let mut dst = None;
|
||||
if let Err(e) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
|
||||
io.want_shutdown(Some(e));
|
||||
}
|
||||
// ask inner filter to process read buf
|
||||
match inner.filter.process_read_buf(io, nbytes) {
|
||||
Err(err) => io.want_shutdown(Some(err)),
|
||||
Ok((_, 0)) => return Ok((0, 0)),
|
||||
Ok(_) => (),
|
||||
}
|
||||
|
||||
if let Some(dst) = dst {
|
||||
dst
|
||||
} else {
|
||||
return Ok(0);
|
||||
}
|
||||
if session.is_handshaking() {
|
||||
Ok((0, 1))
|
||||
} else {
|
||||
// get processed buffer
|
||||
let mut dst = if let Some(dst) = inner.read_buf.take() {
|
||||
dst
|
||||
} else {
|
||||
inner.pool.get_read_buf()
|
||||
};
|
||||
let (hw, lw) = inner.pool.read_params().unpack();
|
||||
|
||||
// get inner filter buffer
|
||||
if dst.is_none() {
|
||||
*dst = Some(inner.pool.get_read_buf());
|
||||
}
|
||||
let buf = dst.as_mut().unwrap();
|
||||
let mut src = if let Some(src) = inner.filter.get_read_buf() {
|
||||
src
|
||||
} else {
|
||||
return Ok((0, 0));
|
||||
};
|
||||
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
let remaining = dst.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
dst.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
let mut cursor = io::Cursor::new(&src);
|
||||
|
@ -132,21 +114,21 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
|||
|
||||
let new_b = state.plaintext_bytes_to_read();
|
||||
if new_b > 0 {
|
||||
buf.reserve(new_b);
|
||||
dst.reserve(new_b);
|
||||
let chunk: &mut [u8] =
|
||||
unsafe { std::mem::transmute(&mut *buf.chunk_mut()) };
|
||||
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
|
||||
let v = session.reader().read(chunk)?;
|
||||
unsafe { buf.advance_mut(v) };
|
||||
unsafe { dst.advance_mut(v) };
|
||||
new_bytes += v;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !src.is_empty() {
|
||||
inner.read_buf = Some(src);
|
||||
}
|
||||
Ok(new_bytes)
|
||||
let dst_len = dst.len();
|
||||
inner.read_buf = Some(dst);
|
||||
inner.filter.release_read_buf(src);
|
||||
Ok((dst_len, new_bytes))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -174,42 +156,6 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
|||
}
|
||||
}
|
||||
|
||||
struct Wrapper<'a, F>(&'a mut IoInner<F>);
|
||||
|
||||
impl<'a, F: Filter> io::Read for Wrapper<'a, F> {
|
||||
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
|
||||
if let Some(read_buf) = self.0.read_buf.as_mut() {
|
||||
let len = cmp::min(read_buf.len(), dst.len());
|
||||
if len > 0 {
|
||||
dst[..len].copy_from_slice(&read_buf.split_to(len));
|
||||
Ok(len)
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||
}
|
||||
} else {
|
||||
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F: Filter> io::Write for Wrapper<'a, F> {
|
||||
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
|
||||
let mut buf = if let Some(mut buf) = self.0.inner.get_write_buf() {
|
||||
buf.reserve(src.len());
|
||||
buf
|
||||
} else {
|
||||
BytesMut::with_capacity_in(src.len(), self.0.pool)
|
||||
};
|
||||
buf.extend_from_slice(src);
|
||||
self.0.inner.release_write_buf(buf)?;
|
||||
Ok(src.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> TlsServerFilter<F> {
|
||||
pub(crate) async fn create(
|
||||
io: Io<F>,
|
||||
|
@ -222,12 +168,11 @@ impl<F: Filter> TlsServerFilter<F> {
|
|||
Ok(session) => session,
|
||||
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)),
|
||||
};
|
||||
let io = io.map_filter(|inner: F| {
|
||||
let read_buf = inner.get_read_buf();
|
||||
let io = io.map_filter(|filter: F| {
|
||||
let inner = IoInner {
|
||||
pool,
|
||||
inner,
|
||||
read_buf,
|
||||
filter,
|
||||
read_buf: None,
|
||||
write_buf: None,
|
||||
};
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue