update tls filters

This commit is contained in:
Nikolay Kim 2021-12-19 10:55:42 +06:00
parent 1ccb87ea51
commit d1071a1f18
5 changed files with 121 additions and 149 deletions

View file

@ -6,9 +6,7 @@ use std::{
};
use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_io::{
Filter, FilterFactory, Io, IoRef, ReadFilter, WriteFilter, WriteReadiness,
};
use ntex_io::{Filter, FilterFactory, Io, IoRef, WriteReadiness};
use ntex_util::{future::poll_fn, time, time::Millis};
use tls_openssl::ssl::{self, SslStream};
@ -103,17 +101,26 @@ impl<F: Filter> Filter for SslFilter<F> {
self.inner.borrow().get_ref().inner.query(id)
}
}
}
impl<F: Filter> ReadFilter for SslFilter<F> {
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.inner.borrow().get_ref().inner.poll_read_ready(cx)
}
fn read_closed(&self, err: Option<io::Error>) {
self.inner.borrow().get_ref().inner.read_closed(err)
#[inline]
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.borrow().get_ref().inner.poll_write_ready(cx)
}
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.inner.borrow().get_ref().inner.closed(err)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
if !buf.is_empty() {
@ -123,19 +130,25 @@ impl<F: Filter> ReadFilter for SslFilter<F> {
None
}
fn release_read_buf(
&self,
src: BytesMut,
new_bytes: usize,
) -> Result<bool, io::Error> {
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_read_buf(&self, src: BytesMut, nbytes: usize) -> Result<(), io::Error> {
// store to read_buf
let pool = {
let mut inner = self.inner.borrow_mut();
inner.get_mut().read_buf = Some(src);
inner.get_ref().pool
};
if new_bytes == 0 {
return Ok(false);
if nbytes == 0 {
return Ok(());
}
let (hw, lw) = pool.read_params().unpack();
@ -177,30 +190,8 @@ impl<F: Filter> ReadFilter for SslFilter<F> {
};
}
}
}
impl<F: Filter> WriteFilter for SslFilter<F> {
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.borrow().get_ref().inner.poll_write_ready(cx)
}
fn write_closed(&self, err: Option<io::Error>) {
self.inner.borrow().get_ref().inner.read_closed(err)
}
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_write_buf(&self, mut buf: BytesMut) -> Result<bool, io::Error> {
fn release_write_buf(&self, mut buf: BytesMut) -> Result<(), io::Error> {
let ssl_result = self.inner.borrow_mut().ssl_write(&buf);
let result = match ssl_result {
Ok(v) => {
@ -208,10 +199,10 @@ impl<F: Filter> WriteFilter for SslFilter<F> {
buf.split_to(v);
self.inner.borrow_mut().get_mut().write_buf = Some(buf);
}
Ok(false)
Ok(())
}
Err(e) => match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(false),
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()),
_ => Err(map_to_ioerr(e)),
},
};

View file

@ -4,7 +4,7 @@ use std::sync::Arc;
use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_io::{Filter, Io, IoRef, ReadFilter, WriteFilter, WriteReadiness};
use ntex_io::{Filter, Io, IoRef, WriteReadiness};
use ntex_util::future::poll_fn;
use tls_rust::{ClientConfig, ClientConnection, ServerName};
@ -25,6 +25,7 @@ struct IoInner<F> {
}
impl<F: Filter> Filter for TlsClientFilter<F> {
#[inline]
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>> {
self.inner.borrow().inner.shutdown(st)
}
@ -50,17 +51,26 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
self.inner.borrow().inner.query(id)
}
}
}
impl<F: Filter> ReadFilter for TlsClientFilter<F> {
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.inner.borrow().inner.poll_read_ready(cx)
}
fn read_closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.read_closed(err)
#[inline]
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.borrow().inner.poll_write_ready(cx)
}
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.closed(err)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
if !buf.is_empty() {
@ -70,18 +80,24 @@ impl<F: Filter> ReadFilter for TlsClientFilter<F> {
None
}
fn release_read_buf(
&self,
mut src: BytesMut,
_nb: usize,
) -> Result<bool, io::Error> {
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
let mut session = self.session.borrow_mut();
if session.is_handshaking() {
self.inner.borrow_mut().read_buf = Some(src);
Ok(false)
Ok(())
} else {
if src.is_empty() {
return Ok(false);
return Ok(());
}
let mut inner = self.inner.borrow_mut();
let (hw, lw) = inner.pool.read_params().unpack();
@ -127,30 +143,8 @@ impl<F: Filter> ReadFilter for TlsClientFilter<F> {
inner.inner.release_read_buf(buf, new_bytes)
}
}
}
impl<F: Filter> WriteFilter for TlsClientFilter<F> {
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.borrow().inner.poll_write_ready(cx)
}
fn write_closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.read_closed(err)
}
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_write_buf(&self, mut src: BytesMut) -> Result<bool, io::Error> {
fn release_write_buf(&self, mut src: BytesMut) -> Result<(), io::Error> {
let mut session = self.session.borrow_mut();
let mut inner = self.inner.borrow_mut();
let mut io = Wrapper(&mut *inner);
@ -171,7 +165,7 @@ impl<F: Filter> WriteFilter for TlsClientFilter<F> {
self.inner.borrow_mut().write_buf = Some(src);
}
Ok(false)
Ok(())
}
}

View file

@ -4,9 +4,7 @@ use std::sync::Arc;
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll};
use ntex_bytes::BytesMut;
use ntex_io::{
Filter, FilterFactory, Io, IoRef, ReadFilter, WriteFilter, WriteReadiness,
};
use ntex_io::{Filter, FilterFactory, Io, IoRef, WriteReadiness};
use ntex_util::time::Millis;
use tls_rust::{ClientConfig, ServerConfig, ServerName};
@ -69,9 +67,16 @@ impl<F: Filter> Filter for TlsFilter<F> {
InnerTlsFilter::Client(ref f) => f.query(id),
}
}
}
impl<F: Filter> ReadFilter for TlsFilter<F> {
#[inline]
fn closed(&self, err: Option<io::Error>) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.closed(err),
InnerTlsFilter::Client(ref f) => f.closed(err),
}
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_read_ready(cx),
@ -79,29 +84,7 @@ impl<F: Filter> ReadFilter for TlsFilter<F> {
}
}
fn read_closed(&self, err: Option<io::Error>) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.read_closed(err),
InnerTlsFilter::Client(ref f) => f.read_closed(err),
}
}
fn get_read_buf(&self) -> Option<BytesMut> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.get_read_buf(),
InnerTlsFilter::Client(ref f) => f.get_read_buf(),
}
}
fn release_read_buf(&self, src: BytesMut, nb: usize) -> Result<bool, io::Error> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(src, nb),
InnerTlsFilter::Client(ref f) => f.release_read_buf(src, nb),
}
}
}
impl<F: Filter> WriteFilter for TlsFilter<F> {
#[inline]
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
@ -112,13 +95,15 @@ impl<F: Filter> WriteFilter for TlsFilter<F> {
}
}
fn write_closed(&self, err: Option<io::Error>) {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.write_closed(err),
InnerTlsFilter::Client(ref f) => f.write_closed(err),
InnerTlsFilter::Server(ref f) => f.get_read_buf(),
InnerTlsFilter::Client(ref f) => f.get_read_buf(),
}
}
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.get_write_buf(),
@ -126,7 +111,16 @@ impl<F: Filter> WriteFilter for TlsFilter<F> {
}
}
fn release_write_buf(&self, src: BytesMut) -> Result<bool, io::Error> {
#[inline]
fn release_read_buf(&self, src: BytesMut, nb: usize) -> Result<(), io::Error> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(src, nb),
InnerTlsFilter::Client(ref f) => f.release_read_buf(src, nb),
}
}
#[inline]
fn release_write_buf(&self, src: BytesMut) -> Result<(), io::Error> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_write_buf(src),
InnerTlsFilter::Client(ref f) => f.release_write_buf(src),

View file

@ -4,12 +4,11 @@ use std::sync::Arc;
use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_io::{Filter, Io, IoRef, ReadFilter, WriteFilter, WriteReadiness};
use ntex_io::{Filter, Io, IoRef, WriteReadiness};
use ntex_util::{future::poll_fn, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection};
use super::TlsFilter;
use crate::types;
use crate::{rustls::TlsFilter, types};
/// An implementation of SSL streams
pub struct TlsServerFilter<F> {
@ -25,6 +24,7 @@ struct IoInner<F> {
}
impl<F: Filter> Filter for TlsServerFilter<F> {
#[inline]
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>> {
self.inner.borrow().inner.shutdown(st)
}
@ -50,17 +50,26 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
self.inner.borrow().inner.query(id)
}
}
}
impl<F: Filter> ReadFilter for TlsServerFilter<F> {
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.closed(err)
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.inner.borrow().inner.poll_read_ready(cx)
}
fn read_closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.read_closed(err)
#[inline]
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.borrow().inner.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
if !buf.is_empty() {
@ -70,18 +79,24 @@ impl<F: Filter> ReadFilter for TlsServerFilter<F> {
None
}
fn release_read_buf(
&self,
mut src: BytesMut,
_nb: usize,
) -> Result<bool, io::Error> {
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
let mut session = self.session.borrow_mut();
if session.is_handshaking() {
self.inner.borrow_mut().read_buf = Some(src);
Ok(false)
Ok(())
} else {
if src.is_empty() {
return Ok(false);
return Ok(());
}
let mut inner = self.inner.borrow_mut();
let (hw, lw) = inner.pool.read_params().unpack();
@ -127,30 +142,8 @@ impl<F: Filter> ReadFilter for TlsServerFilter<F> {
inner.inner.release_read_buf(buf, new_bytes)
}
}
}
impl<F: Filter> WriteFilter for TlsServerFilter<F> {
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
self.inner.borrow().inner.poll_write_ready(cx)
}
fn write_closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.read_closed(err)
}
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().write_buf.take() {
if !buf.is_empty() {
return Some(buf);
}
}
None
}
fn release_write_buf(&self, mut src: BytesMut) -> Result<bool, io::Error> {
fn release_write_buf(&self, mut src: BytesMut) -> Result<(), io::Error> {
let mut session = self.session.borrow_mut();
let mut inner = self.inner.borrow_mut();
let mut io = Wrapper(&mut *inner);
@ -171,7 +164,7 @@ impl<F: Filter> WriteFilter for TlsServerFilter<F> {
self.inner.borrow_mut().write_buf = Some(src);
}
Ok(false)
Ok(())
}
}

View file

@ -230,10 +230,10 @@ impl<T> SendError<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::{util::lazy, util::next, Stream};
use crate::{future::lazy, future::next, Stream};
use futures_sink::Sink;
#[crate::rt_test]
#[ntex_macros::rt_test2]
async fn test_mpsc() {
let (tx, mut rx) = channel();
assert!(format!("{:?}", tx).contains("Sender"));
@ -282,7 +282,7 @@ mod tests {
assert_eq!(err.into_inner(), "test");
}
#[crate::rt_test]
#[ntex_macros::rt_test2]
async fn test_sink() {
let (mut tx, mut rx) = channel();
lazy(|cx| {
@ -296,7 +296,7 @@ mod tests {
assert_eq!(next(&mut rx).await, None);
}
#[crate::rt_test]
#[ntex_macros::rt_test2]
async fn test_close() {
let (tx, rx) = channel::<()>();
assert!(!tx.is_closed());