#![allow(clippy::type_complexity)] //! An implementation of SSL streams for ntex backed by OpenSSL use std::cell::RefCell; use std::{ any, cmp, error::Error, future::Future, io, pin::Pin, task::Context, task::Poll, }; use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_io::{Base, Filter, FilterFactory, Io, ReadStatus, WriteStatus}; use ntex_util::{future::poll_fn, ready, time, time::Millis}; use tls_openssl::ssl::{self, SslStream}; mod accept; pub use self::accept::{Acceptor, AcceptorService}; use super::types; /// An implementation of SSL streams pub struct SslFilter { inner: RefCell>>, } struct IoInner { inner: F, pool: PoolRef, read_buf: Option, write_buf: Option, } impl io::Read for IoInner { fn read(&mut self, dst: &mut [u8]) -> io::Result { if let Some(ref mut buf) = self.read_buf { if buf.is_empty() { buf.clear(); Err(io::Error::from(io::ErrorKind::WouldBlock)) } else { let len = cmp::min(buf.len(), dst.len()); dst[..len].copy_from_slice(&buf.split_to(len)); Ok(len) } } else { Err(io::Error::from(io::ErrorKind::WouldBlock)) } } } impl io::Write for IoInner { fn write(&mut self, src: &[u8]) -> io::Result { let mut buf = if let Some(mut buf) = self.inner.get_write_buf() { buf.reserve(src.len()); buf } else { BytesMut::with_capacity_in(src.len(), self.pool) }; buf.extend_from_slice(src); self.inner.release_write_buf(buf)?; Ok(src.len()) } fn flush(&mut self) -> io::Result<()> { Ok(()) } } impl Filter for SslFilter { fn query(&self, id: any::TypeId) -> Option> { if id == any::TypeId::of::() { let proto = if let Some(protos) = self.inner.borrow().ssl().selected_alpn_protocol() { if protos.windows(2).any(|window| window == b"h2") { types::HttpProtocol::Http2 } else { types::HttpProtocol::Http1 } } else { types::HttpProtocol::Http1 }; Some(Box::new(proto)) } else { self.inner.borrow().get_ref().inner.query(id) } } fn poll_shutdown(&self) -> Poll> { let ssl_result = self.inner.borrow_mut().shutdown(); match ssl_result { Ok(ssl::ShutdownResult::Sent) => Poll::Pending, Ok(ssl::ShutdownResult::Received) => { self.inner.borrow().get_ref().inner.poll_shutdown() } Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Poll::Ready(Ok(())), Err(ref e) if e.code() == ssl::ErrorCode::WANT_READ || e.code() == ssl::ErrorCode::WANT_WRITE => { Poll::Pending } Err(e) => Poll::Ready(Err(e .into_io_error() .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))), } } #[inline] fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { self.inner.borrow().get_ref().inner.poll_read_ready(cx) } #[inline] fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { self.inner.borrow().get_ref().inner.poll_write_ready(cx) } #[inline] fn closed(&self, err: Option) { self.inner.borrow().get_ref().inner.closed(err) } #[inline] fn want_read(&self) { self.inner.borrow().get_ref().inner.want_read() } #[inline] fn want_shutdown(&self) { self.inner.borrow().get_ref().inner.want_shutdown() } #[inline] fn get_read_buf(&self) -> Option { if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() { if !buf.is_empty() { return Some(buf); } } None } #[inline] fn get_write_buf(&self) -> Option { if let Some(buf) = self.inner.borrow_mut().get_mut().write_buf.take() { if !buf.is_empty() { return Some(buf); } } None } fn release_read_buf(&self, src: BytesMut, nbytes: usize) -> Result<(), io::Error> { // store to read_buf let pool = { let mut inner = self.inner.borrow_mut(); inner.get_mut().read_buf = Some(src); inner.get_ref().pool }; if nbytes == 0 { return Ok(()); } let (hw, lw) = pool.read_params().unpack(); // get inner filter buffer let mut buf = if let Some(buf) = self.inner.borrow().get_ref().inner.get_read_buf() { buf } else { BytesMut::with_capacity_in(lw, pool) }; let mut new_bytes = 0; loop { // make sure we've got room let remaining = buf.remaining_mut(); if remaining < lw { buf.reserve(hw - remaining); } let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *buf.chunk_mut()) }; let ssl_result = self.inner.borrow_mut().ssl_read(chunk); return match ssl_result { Ok(v) => { unsafe { buf.advance_mut(v) }; new_bytes += v; continue; } Err(ref e) if e.code() == ssl::ErrorCode::WANT_READ || e.code() == ssl::ErrorCode::WANT_WRITE => { self.inner .borrow() .get_ref() .inner .release_read_buf(buf, new_bytes) } Err(e) => Err(map_to_ioerr(e)), }; } } fn release_write_buf(&self, mut buf: BytesMut) -> Result<(), io::Error> { let ssl_result = self.inner.borrow_mut().ssl_write(&buf); let result = match ssl_result { Ok(v) => { if v != buf.len() { buf.split_to(v); self.inner.borrow_mut().get_mut().write_buf = Some(buf); } Ok(()) } Err(e) => match e.code() { ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()), _ => Err(map_to_ioerr(e)), }, }; result } } pub struct SslAcceptor { acceptor: ssl::SslAcceptor, timeout: Millis, } impl SslAcceptor { /// Create openssl acceptor filter factory pub fn new(acceptor: ssl::SslAcceptor) -> Self { SslAcceptor { acceptor, timeout: Millis(5_000), } } /// Set handshake timeout. /// /// Default is set to 5 seconds. pub fn timeout>(&mut self, timeout: U) -> &mut Self { self.timeout = timeout.into(); self } } impl Clone for SslAcceptor { fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), timeout: self.timeout, } } } impl FilterFactory for SslAcceptor { type Filter = SslFilter; type Error = Box; type Future = Pin, Self::Error>>>>; fn create(self, st: Io) -> Self::Future { let timeout = self.timeout; let ctx_result = ssl::Ssl::new(self.acceptor.context()); Box::pin(async move { time::timeout(timeout, async { let ssl = ctx_result.map_err(map_to_ioerr)?; let pool = st.memory_pool(); let st = st.map_filter(|inner: F| { let read_buf = inner.get_read_buf(); let inner = IoInner { pool, inner, read_buf, write_buf: None, }; let ssl_stream = ssl::SslStream::new(ssl, inner)?; Ok::<_, Box>(SslFilter { inner: RefCell::new(ssl_stream), }) })?; poll_fn(|cx| { handle_result(st.filter().inner.borrow_mut().accept(), &st, cx) }) .await?; Ok(st) }) .await .map_err(|_| { io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into() }) .and_then(|item| item) }) } } pub struct SslConnector { ssl: ssl::Ssl, } impl SslConnector { /// Create openssl connector filter factory pub fn new(ssl: ssl::Ssl) -> Self { SslConnector { ssl } } } impl FilterFactory for SslConnector { type Filter = SslFilter; type Error = Box; type Future = Pin, Self::Error>>>>; fn create(self, st: Io) -> Self::Future { Box::pin(async move { let ssl = self.ssl; let pool = st.memory_pool(); let st = st.map_filter(|inner: F| { let read_buf = inner.get_read_buf(); let inner = IoInner { pool, inner, read_buf, write_buf: None, }; let ssl_stream = ssl::SslStream::new(ssl, inner)?; Ok::<_, Box>(SslFilter { inner: RefCell::new(ssl_stream), }) })?; poll_fn(|cx| handle_result(st.filter().inner.borrow_mut().connect(), &st, cx)) .await?; Ok(st) }) } } fn handle_result( result: Result, io: &Io, cx: &mut Context<'_>, ) -> Poll>> { match result { Ok(v) => Poll::Ready(Ok(v)), Err(e) => match e.code() { ssl::ErrorCode::WANT_READ => { match ready!(io.poll_read_ready(cx)) { Ok(None) => Err::<_, Box>( io::Error::new(io::ErrorKind::Other, "disconnected").into(), ), Err(err) => Err(err.into()), _ => Ok(()), }?; Poll::Pending } ssl::ErrorCode::WANT_WRITE => { let _ = io.poll_flush(cx, true)?; Poll::Pending } _ => Poll::Ready(Err(Box::new(e))), }, } } fn map_to_ioerr>>(err: E) -> io::Error { io::Error::new(io::ErrorKind::Other, err) }