mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-03 21:07:39 +03:00
Fix read filters ordering
This commit is contained in:
parent
531bafbae2
commit
079a1c9cbf
23 changed files with 642 additions and 325 deletions
|
@ -1,5 +1,9 @@
|
|||
# Changes
|
||||
|
||||
## [0.1.0-b.6] - 2021-12-xx
|
||||
|
||||
* Fix read filters ordering
|
||||
|
||||
## [0.1.0-b.5] - 2021-12-24
|
||||
|
||||
* Use new ntex-service traits
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex-io"
|
||||
version = "0.1.0-b.5"
|
||||
version = "0.1.0-b.6"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "Utilities for encoding and decoding frames"
|
||||
keywords = ["network", "framework", "async", "futures"]
|
||||
|
|
|
@ -67,7 +67,7 @@ impl Filter for Base {
|
|||
|
||||
if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
Poll::Ready(ReadStatus::Terminate)
|
||||
} else if flags.intersects(Flags::RD_PAUSED) {
|
||||
} else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
|
||||
self.0 .0.read_task.register(cx.waker());
|
||||
Poll::Pending
|
||||
} else {
|
||||
|
@ -98,7 +98,7 @@ impl Filter for Base {
|
|||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
self.0 .0.read_buf.take()
|
||||
None
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -107,21 +107,18 @@ impl Filter for Base {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error> {
|
||||
if nbytes > 0 {
|
||||
if buf.len() > self.0.memory_pool().read_params().high as usize {
|
||||
log::trace!(
|
||||
"buffer is too large {}, enable read back-pressure",
|
||||
buf.len()
|
||||
);
|
||||
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
|
||||
} else {
|
||||
self.0 .0.insert_flags(Flags::RD_READY);
|
||||
}
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
buf: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
if let Some(ref mut dst) = dst {
|
||||
dst.extend_from_slice(&buf)
|
||||
} else {
|
||||
*dst = Some(buf)
|
||||
}
|
||||
|
||||
self.0 .0.read_buf.set(Some(buf));
|
||||
Ok(())
|
||||
Ok(nbytes)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -178,8 +175,13 @@ impl Filter for NullFilter {
|
|||
None
|
||||
}
|
||||
|
||||
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> {
|
||||
Ok(())
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
_: BytesMut,
|
||||
_: &mut Option<BytesMut>,
|
||||
_: usize,
|
||||
) -> io::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> {
|
||||
|
|
|
@ -279,7 +279,8 @@ impl fmt::Debug for IoRef {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::{future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
|
||||
|
||||
use ntex_bytes::Bytes;
|
||||
use ntex_codec::BytesCodec;
|
||||
|
@ -393,9 +394,12 @@ mod tests {
|
|||
}
|
||||
|
||||
struct Counter<F> {
|
||||
idx: usize,
|
||||
inner: F,
|
||||
in_bytes: Rc<Cell<usize>>,
|
||||
out_bytes: Rc<Cell<usize>>,
|
||||
read_order: Rc<RefCell<Vec<usize>>>,
|
||||
write_order: Rc<RefCell<Vec<usize>>>,
|
||||
}
|
||||
impl<F: Filter> Filter for Counter<F> {
|
||||
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
|
||||
|
@ -425,10 +429,13 @@ mod tests {
|
|||
fn release_read_buf(
|
||||
&self,
|
||||
buf: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
new_bytes: usize,
|
||||
) -> Result<(), io::Error> {
|
||||
self.in_bytes.set(self.in_bytes.get() + new_bytes);
|
||||
self.inner.release_read_buf(buf, new_bytes)
|
||||
) -> io::Result<usize> {
|
||||
let result = self.inner.release_read_buf(buf, dst, new_bytes)?;
|
||||
self.read_order.borrow_mut().push(self.idx);
|
||||
self.in_bytes.set(self.in_bytes.get() + result);
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
||||
|
@ -445,12 +452,19 @@ mod tests {
|
|||
}
|
||||
|
||||
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
|
||||
self.write_order.borrow_mut().push(self.idx);
|
||||
self.out_bytes.set(self.out_bytes.get() + buf.len());
|
||||
self.inner.release_write_buf(buf)
|
||||
}
|
||||
}
|
||||
|
||||
struct CounterFactory(Rc<Cell<usize>>, Rc<Cell<usize>>);
|
||||
struct CounterFactory(
|
||||
usize,
|
||||
Rc<Cell<usize>>,
|
||||
Rc<Cell<usize>>,
|
||||
Rc<RefCell<Vec<usize>>>,
|
||||
Rc<RefCell<Vec<usize>>>,
|
||||
);
|
||||
|
||||
impl<F: Filter> FilterFactory<F> for CounterFactory {
|
||||
type Filter = Counter<F>;
|
||||
|
@ -459,14 +473,20 @@ mod tests {
|
|||
type Future = Ready<Io<Counter<F>>, Self::Error>;
|
||||
|
||||
fn create(self, io: Io<F>) -> Self::Future {
|
||||
let in_bytes = self.0.clone();
|
||||
let out_bytes = self.1.clone();
|
||||
let idx = self.0;
|
||||
let in_bytes = self.1.clone();
|
||||
let out_bytes = self.2.clone();
|
||||
let read_order = self.3.clone();
|
||||
let write_order = self.4.clone();
|
||||
Ready::Ok(
|
||||
io.map_filter(|inner| {
|
||||
Ok::<_, ()>(Counter {
|
||||
idx,
|
||||
inner,
|
||||
in_bytes,
|
||||
out_bytes,
|
||||
read_order,
|
||||
write_order,
|
||||
})
|
||||
})
|
||||
.unwrap(),
|
||||
|
@ -478,7 +498,15 @@ mod tests {
|
|||
async fn filter() {
|
||||
let in_bytes = Rc::new(Cell::new(0));
|
||||
let out_bytes = Rc::new(Cell::new(0));
|
||||
let factory = CounterFactory(in_bytes.clone(), out_bytes.clone());
|
||||
let read_order = Rc::new(RefCell::new(Vec::new()));
|
||||
let write_order = Rc::new(RefCell::new(Vec::new()));
|
||||
let factory = CounterFactory(
|
||||
1,
|
||||
in_bytes.clone(),
|
||||
out_bytes.clone(),
|
||||
read_order.clone(),
|
||||
write_order.clone(),
|
||||
);
|
||||
|
||||
let (client, server) = IoTest::create();
|
||||
let state = Io::new(server).add_filter(factory).await.unwrap();
|
||||
|
@ -503,13 +531,27 @@ mod tests {
|
|||
async fn boxed_filter() {
|
||||
let in_bytes = Rc::new(Cell::new(0));
|
||||
let out_bytes = Rc::new(Cell::new(0));
|
||||
let read_order = Rc::new(RefCell::new(Vec::new()));
|
||||
let write_order = Rc::new(RefCell::new(Vec::new()));
|
||||
|
||||
let (client, server) = IoTest::create();
|
||||
let state = Io::new(server)
|
||||
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||
.add_filter(CounterFactory(
|
||||
1,
|
||||
in_bytes.clone(),
|
||||
out_bytes.clone(),
|
||||
read_order.clone(),
|
||||
write_order.clone(),
|
||||
))
|
||||
.await
|
||||
.unwrap()
|
||||
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||
.add_filter(CounterFactory(
|
||||
2,
|
||||
in_bytes.clone(),
|
||||
out_bytes.clone(),
|
||||
read_order.clone(),
|
||||
write_order.clone(),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
let state = state.seal();
|
||||
|
@ -533,5 +575,7 @@ mod tests {
|
|||
assert_eq!(Rc::strong_count(&in_bytes), 3);
|
||||
drop(state);
|
||||
assert_eq!(Rc::strong_count(&in_bytes), 1);
|
||||
assert_eq!(*read_order.borrow(), &[1, 2][..]);
|
||||
assert_eq!(*write_order.borrow(), &[2, 1][..]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::{
|
||||
any::Any, any::TypeId, fmt, future::Future, io::Error as IoError, task::Context,
|
||||
task::Poll,
|
||||
any::Any, any::TypeId, fmt, future::Future, io as sio, io::Error as IoError,
|
||||
task::Context, task::Poll,
|
||||
};
|
||||
|
||||
pub mod testing;
|
||||
|
@ -55,7 +55,7 @@ pub trait Filter: 'static {
|
|||
/// Filter wants gracefully shutdown io stream
|
||||
fn want_shutdown(&self);
|
||||
|
||||
fn poll_shutdown(&self) -> Poll<std::io::Result<()>>;
|
||||
fn poll_shutdown(&self) -> Poll<sio::Result<()>>;
|
||||
|
||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
|
||||
|
||||
|
@ -65,11 +65,16 @@ pub trait Filter: 'static {
|
|||
|
||||
fn get_write_buf(&self) -> Option<BytesMut>;
|
||||
|
||||
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> std::io::Result<()>;
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> sio::Result<usize>;
|
||||
|
||||
fn release_write_buf(&self, buf: BytesMut) -> std::io::Result<()>;
|
||||
fn release_write_buf(&self, buf: BytesMut) -> sio::Result<()>;
|
||||
|
||||
fn closed(&self, err: Option<std::io::Error>);
|
||||
fn closed(&self, err: Option<sio::Error>);
|
||||
}
|
||||
|
||||
pub trait FilterFactory<F: Filter>: Sized {
|
||||
|
|
|
@ -31,27 +31,35 @@ impl ReadContext {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub fn release_read_buf(
|
||||
&self,
|
||||
buf: BytesMut,
|
||||
new_bytes: usize,
|
||||
) -> Result<(), io::Error> {
|
||||
pub fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error> {
|
||||
if buf.is_empty() {
|
||||
self.0.memory_pool().release_read_buf(buf);
|
||||
Ok(())
|
||||
} else {
|
||||
let mut flags = self.0.flags();
|
||||
if new_bytes > 0 {
|
||||
flags.insert(Flags::RD_READY);
|
||||
self.0.set_flags(flags);
|
||||
let mut dst = self.0 .0.read_buf.take();
|
||||
let nbytes = self.0.filter().release_read_buf(buf, &mut dst, nbytes)?;
|
||||
|
||||
if let Some(dst) = dst {
|
||||
if self.0.flags().contains(Flags::IO_FILTERS) {
|
||||
self.0 .0.shutdown_filters()?;
|
||||
}
|
||||
if nbytes > 0 {
|
||||
if dst.len() > self.0.memory_pool().read_params().high as usize {
|
||||
log::trace!(
|
||||
"buffer is too large {}, enable read back-pressure",
|
||||
dst.len()
|
||||
);
|
||||
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
|
||||
} else {
|
||||
self.0 .0.insert_flags(Flags::RD_READY);
|
||||
log::trace!("new {} bytes available, wakeup dispatcher", nbytes);
|
||||
}
|
||||
self.0 .0.dispatch_task.wake();
|
||||
}
|
||||
self.0 .0.read_buf.set(Some(dst));
|
||||
} else if nbytes > 0 {
|
||||
self.0 .0.dispatch_task.wake();
|
||||
log::trace!("new {} bytes available, wakeup dispatcher", new_bytes);
|
||||
}
|
||||
|
||||
self.0.filter().release_read_buf(buf, new_bytes)?;
|
||||
|
||||
if flags.contains(Flags::IO_FILTERS) {
|
||||
self.0 .0.shutdown_filters()?;
|
||||
self.0 .0.insert_flags(Flags::RD_READY);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -51,62 +51,72 @@ impl Future for ReadTask {
|
|||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.as_ref();
|
||||
|
||||
match this.state.poll_ready(cx) {
|
||||
Poll::Ready(ReadStatus::Ready) => {
|
||||
let pool = this.state.memory_pool();
|
||||
let mut io = this.io.borrow_mut();
|
||||
let mut buf = self.state.get_read_buf();
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
loop {
|
||||
match ready!(this.state.poll_ready(cx)) {
|
||||
ReadStatus::Ready => {
|
||||
let pool = this.state.memory_pool();
|
||||
let mut io = this.io.borrow_mut();
|
||||
let mut buf = self.state.get_read_buf();
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
|
||||
// read data from socket
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
// read data from socket
|
||||
let mut new_bytes = 0;
|
||||
let mut close = false;
|
||||
let mut pending = false;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("io stream is disconnected");
|
||||
if let Err(e) = this.state.release_read_buf(buf, new_bytes)
|
||||
{
|
||||
this.state.close(Some(e));
|
||||
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
|
||||
Poll::Pending => {
|
||||
pending = true;
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("tokio stream is disconnected");
|
||||
close = true;
|
||||
} else {
|
||||
this.state.close(None);
|
||||
new_bytes += n;
|
||||
if new_bytes < hw {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("read task failed on io {:?}", err);
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(Some(err));
|
||||
return Poll::Ready(());
|
||||
} else {
|
||||
new_bytes += n;
|
||||
if buf.len() > hw {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("read task failed on io {:?}", err);
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(Some(err));
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
|
||||
this.state.close(Some(e));
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
Poll::Pending
|
||||
if new_bytes == 0 && close {
|
||||
this.state.close(None);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
return if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
|
||||
this.state.close(Some(e));
|
||||
Poll::Ready(())
|
||||
} else if close {
|
||||
this.state.close(None);
|
||||
Poll::Ready(())
|
||||
} else if pending {
|
||||
Poll::Pending
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
}
|
||||
ReadStatus::Terminate => {
|
||||
log::trace!("read task is instructed to shutdown");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
Poll::Ready(ReadStatus::Terminate) => {
|
||||
log::trace!("read task is instructed to shutdown");
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -466,63 +476,72 @@ mod unixstream {
|
|||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.as_ref();
|
||||
|
||||
match this.state.poll_ready(cx) {
|
||||
Poll::Ready(ReadStatus::Ready) => {
|
||||
let pool = this.state.memory_pool();
|
||||
let mut io = this.io.borrow_mut();
|
||||
let mut buf = self.state.get_read_buf();
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
loop {
|
||||
match ready!(this.state.poll_ready(cx)) {
|
||||
ReadStatus::Ready => {
|
||||
let pool = this.state.memory_pool();
|
||||
let mut io = this.io.borrow_mut();
|
||||
let mut buf = self.state.get_read_buf();
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
|
||||
// read data from socket
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
// read data from socket
|
||||
let mut new_bytes = 0;
|
||||
let mut close = false;
|
||||
let mut pending = false;
|
||||
loop {
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
|
||||
Poll::Pending => break,
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("io stream is disconnected");
|
||||
if let Err(e) =
|
||||
this.state.release_read_buf(buf, new_bytes)
|
||||
{
|
||||
this.state.close(Some(e));
|
||||
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
|
||||
Poll::Pending => {
|
||||
pending = true;
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("unix stream is disconnected");
|
||||
close = true;
|
||||
} else {
|
||||
this.state.close(None);
|
||||
new_bytes += n;
|
||||
if new_bytes < hw {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("read task failed on io {:?}", err);
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(Some(err));
|
||||
return Poll::Ready(());
|
||||
} else {
|
||||
new_bytes += n;
|
||||
if buf.len() > hw {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("read task failed on io {:?}", err);
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(Some(err));
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
|
||||
this.state.close(Some(e));
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
Poll::Pending
|
||||
if new_bytes == 0 && close {
|
||||
this.state.close(None);
|
||||
return Poll::Ready(());
|
||||
}
|
||||
return if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
|
||||
this.state.close(Some(e));
|
||||
Poll::Ready(())
|
||||
} else if close {
|
||||
this.state.close(None);
|
||||
Poll::Ready(())
|
||||
} else if pending {
|
||||
Poll::Pending
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
}
|
||||
ReadStatus::Terminate => {
|
||||
log::trace!("read task is instructed to shutdown");
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
Poll::Ready(ReadStatus::Terminate) => {
|
||||
log::trace!("read task is instructed to shutdown");
|
||||
Poll::Ready(())
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex-tls"
|
||||
version = "0.1.0-b.3"
|
||||
version = "0.1.0-b.4"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
||||
keywords = ["network", "framework", "async", "futures"]
|
||||
|
@ -26,7 +26,7 @@ rustls = ["tls_rust"]
|
|||
|
||||
[dependencies]
|
||||
ntex-bytes = "0.1.8"
|
||||
ntex-io = "0.1.0-b.5"
|
||||
ntex-io = "0.1.0-b.6"
|
||||
ntex-util = "0.1.4"
|
||||
ntex-service = "0.3.0-b.0"
|
||||
pin-project-lite = "0.2"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#![allow(clippy::type_complexity)]
|
||||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::{
|
||||
any, cmp, error::Error, future::Future, io, pin::Pin, task::Context, task::Poll,
|
||||
};
|
||||
|
@ -18,6 +18,7 @@ use super::types;
|
|||
/// An implementation of SSL streams
|
||||
pub struct SslFilter<F = Base> {
|
||||
inner: RefCell<SslStream<IoInner<F>>>,
|
||||
handshake: Cell<bool>,
|
||||
}
|
||||
|
||||
struct IoInner<F> {
|
||||
|
@ -128,12 +129,15 @@ impl<F: Filter> Filter for SslFilter<F> {
|
|||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
inner.get_ref().inner.get_read_buf().or_else(|| {
|
||||
if let Some(buf) = inner.get_mut().read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -146,25 +150,34 @@ impl<F: Filter> Filter for SslFilter<F> {
|
|||
None
|
||||
}
|
||||
|
||||
fn release_read_buf(&self, src: BytesMut, nbytes: usize) -> Result<(), io::Error> {
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
// store to read_buf
|
||||
let pool = {
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
inner.get_mut().read_buf = Some(src);
|
||||
inner.get_ref().pool
|
||||
let mut dst = None;
|
||||
inner
|
||||
.get_ref()
|
||||
.inner
|
||||
.release_read_buf(src, &mut dst, nbytes)?;
|
||||
if dst.is_some() {
|
||||
inner.get_mut().read_buf = dst;
|
||||
inner.get_ref().pool
|
||||
} else {
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
if nbytes == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let (hw, lw) = pool.read_params().unpack();
|
||||
|
||||
// get inner filter buffer
|
||||
let mut buf = if let Some(buf) = self.inner.borrow().get_ref().inner.get_read_buf()
|
||||
{
|
||||
buf
|
||||
} else {
|
||||
BytesMut::with_capacity_in(lw, pool)
|
||||
};
|
||||
if dst.is_none() {
|
||||
*dst = Some(pool.get_read_buf());
|
||||
}
|
||||
let buf = dst.as_mut().unwrap();
|
||||
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
|
@ -186,11 +199,13 @@ impl<F: Filter> Filter for SslFilter<F> {
|
|||
if e.code() == ssl::ErrorCode::WANT_READ
|
||||
|| e.code() == ssl::ErrorCode::WANT_WRITE =>
|
||||
{
|
||||
self.inner
|
||||
.borrow()
|
||||
.get_ref()
|
||||
.inner
|
||||
.release_read_buf(buf, new_bytes)
|
||||
if new_bytes == 0 && self.handshake.get() {
|
||||
new_bytes = 1;
|
||||
if self.inner.borrow().ssl().is_init_finished() {
|
||||
self.handshake.set(false);
|
||||
}
|
||||
}
|
||||
Ok(new_bytes)
|
||||
}
|
||||
Err(e) => Err(map_to_ioerr(e)),
|
||||
};
|
||||
|
@ -274,6 +289,7 @@ impl<F: Filter> FilterFactory<F> for SslAcceptor {
|
|||
|
||||
Ok::<_, Box<dyn Error>>(SslFilter {
|
||||
inner: RefCell::new(ssl_stream),
|
||||
handshake: Cell::new(true),
|
||||
})
|
||||
})?;
|
||||
|
||||
|
@ -326,6 +342,7 @@ impl<F: Filter> FilterFactory<F> for SslConnector {
|
|||
|
||||
Ok::<_, Box<dyn Error>>(SslFilter {
|
||||
inner: RefCell::new(ssl_stream),
|
||||
handshake: Cell::new(true),
|
||||
})
|
||||
})?;
|
||||
|
||||
|
|
|
@ -78,12 +78,15 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
|||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
inner.inner.get_read_buf().or_else(|| {
|
||||
if let Some(buf) = inner.read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -96,24 +99,36 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
|||
None
|
||||
}
|
||||
|
||||
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
let mut session = self.session.borrow_mut();
|
||||
|
||||
if session.is_handshaking() {
|
||||
self.inner.borrow_mut().read_buf = Some(src);
|
||||
Ok(())
|
||||
inner.read_buf = Some(src);
|
||||
Ok(1)
|
||||
} else {
|
||||
if src.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
let mut src = {
|
||||
let mut dst = None;
|
||||
inner.inner.release_read_buf(src, &mut dst, nbytes)?;
|
||||
|
||||
if let Some(dst) = dst {
|
||||
dst
|
||||
} else {
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
let (hw, lw) = inner.pool.read_params().unpack();
|
||||
|
||||
// get inner filter buffer
|
||||
let mut buf = if let Some(buf) = inner.inner.get_read_buf() {
|
||||
buf
|
||||
} else {
|
||||
BytesMut::with_capacity_in(lw, inner.pool)
|
||||
};
|
||||
if dst.is_none() {
|
||||
*dst = Some(inner.pool.get_read_buf());
|
||||
}
|
||||
let buf = dst.as_mut().unwrap();
|
||||
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
|
@ -146,7 +161,7 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
|
|||
if !src.is_empty() {
|
||||
inner.read_buf = Some(src);
|
||||
}
|
||||
inner.inner.release_read_buf(buf, new_bytes)
|
||||
Ok(new_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -125,10 +125,15 @@ impl<F: Filter> Filter for TlsFilter<F> {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn release_read_buf(&self, src: BytesMut, nb: usize) -> Result<(), io::Error> {
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nb: usize,
|
||||
) -> io::Result<usize> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.release_read_buf(src, nb),
|
||||
InnerTlsFilter::Client(ref f) => f.release_read_buf(src, nb),
|
||||
InnerTlsFilter::Server(ref f) => f.release_read_buf(src, dst, nb),
|
||||
InnerTlsFilter::Client(ref f) => f.release_read_buf(src, dst, nb),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -78,12 +78,15 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
|||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
inner.inner.get_read_buf().or_else(|| {
|
||||
if let Some(buf) = inner.read_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
return Some(buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -96,25 +99,36 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
|||
None
|
||||
}
|
||||
|
||||
fn release_read_buf(&self, mut src: BytesMut, _nb: usize) -> Result<(), io::Error> {
|
||||
let mut session = self.session.borrow_mut();
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
let mut inner = self.inner.borrow_mut();
|
||||
let mut session = self.session.borrow_mut();
|
||||
|
||||
if session.is_handshaking() {
|
||||
inner.read_buf = Some(src);
|
||||
Ok(())
|
||||
Ok(1)
|
||||
} else {
|
||||
if src.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let mut src = {
|
||||
let mut dst = None;
|
||||
inner.inner.release_read_buf(src, &mut dst, nbytes)?;
|
||||
|
||||
if let Some(dst) = dst {
|
||||
dst
|
||||
} else {
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
let (hw, lw) = inner.pool.read_params().unpack();
|
||||
|
||||
// get inner filter buffer
|
||||
let mut buf = if let Some(buf) = inner.inner.get_read_buf() {
|
||||
buf
|
||||
} else {
|
||||
BytesMut::with_capacity_in(lw, inner.pool)
|
||||
};
|
||||
if dst.is_none() {
|
||||
*dst = Some(inner.pool.get_read_buf());
|
||||
}
|
||||
let buf = dst.as_mut().unwrap();
|
||||
|
||||
let mut new_bytes = 0;
|
||||
loop {
|
||||
|
@ -147,7 +161,7 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
|
|||
if !src.is_empty() {
|
||||
inner.read_buf = Some(src);
|
||||
}
|
||||
inner.inner.release_read_buf(buf, new_bytes)
|
||||
Ok(new_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
# Changes
|
||||
|
||||
## [0.5.0-b.4] - 2021-12-xx
|
||||
|
||||
* Allow to get access to ws transport codec
|
||||
|
||||
## [0.5.0-b.3] - 2021-12-24
|
||||
|
||||
* Use new ntex-service traits
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex"
|
||||
version = "0.5.0-b.3"
|
||||
version = "0.5.0-b.4"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "Framework for composable network services"
|
||||
readme = "README.md"
|
||||
|
@ -45,8 +45,8 @@ ntex-service = "0.3.0-b.0"
|
|||
ntex-macros = "0.1.3"
|
||||
ntex-util = "0.1.4"
|
||||
ntex-bytes = "0.1.8"
|
||||
ntex-tls = "0.1.0-b.3"
|
||||
ntex-io = "0.1.0-b.5"
|
||||
ntex-tls = "0.1.0-b.4"
|
||||
ntex-io = "0.1.0-b.6"
|
||||
ntex-rt = { version = "0.4.0-b.2", default-features = false, features = ["tokio"] }
|
||||
|
||||
base64 = "0.13"
|
||||
|
|
|
@ -154,11 +154,9 @@ mod tests {
|
|||
let factory = Connector::new(config).clone();
|
||||
|
||||
let srv = factory.new_service(()).await.unwrap();
|
||||
let _result = srv
|
||||
let result = srv
|
||||
.call(Connect::new("www.rust-lang.org").set_addr(Some(server.addr())))
|
||||
.await;
|
||||
|
||||
// TODO! fix
|
||||
// assert!(result.is_err());
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -340,7 +340,9 @@ impl TestServer {
|
|||
|
||||
#[cfg(feature = "openssl")]
|
||||
/// Connect to a websocket server
|
||||
pub async fn wss(&mut self) -> Result<WsConnection<impl Filter>, WsClientError> {
|
||||
pub async fn wss(
|
||||
&mut self,
|
||||
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> {
|
||||
self.wss_at("/").await
|
||||
}
|
||||
|
||||
|
@ -349,7 +351,7 @@ impl TestServer {
|
|||
pub async fn wss_at(
|
||||
&mut self,
|
||||
path: &str,
|
||||
) -> Result<WsConnection<impl Filter>, WsClientError> {
|
||||
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> {
|
||||
use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
|
||||
|
||||
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||
|
|
|
@ -81,7 +81,9 @@ pub mod io {
|
|||
|
||||
pub mod testing {
|
||||
//! IO testing utilities.
|
||||
#[doc(hidden)]
|
||||
pub use ntex_io::testing::IoTest as Io;
|
||||
pub use ntex_io::testing::IoTest;
|
||||
}
|
||||
|
||||
pub mod tls {
|
||||
|
|
|
@ -57,12 +57,7 @@ struct Inner<F, T> {
|
|||
|
||||
impl WsClient<Base, ()> {
|
||||
/// Create new websocket client builder
|
||||
pub fn build<U>(
|
||||
uri: U,
|
||||
) -> WsClientBuilder<
|
||||
Base,
|
||||
impl Service<Connect<Uri>, Response = Io, Error = ConnectError>,
|
||||
>
|
||||
pub fn build<U>(uri: U) -> WsClientBuilder<Base, Connector<Uri>>
|
||||
where
|
||||
Uri: TryFrom<U>,
|
||||
<Uri as TryFrom<U>>::Error: Into<HttpError>,
|
||||
|
@ -269,12 +264,7 @@ impl<F, T> fmt::Debug for WsClient<F, T> {
|
|||
|
||||
impl WsClientBuilder<Base, ()> {
|
||||
/// Create new websocket connector
|
||||
fn new<U>(
|
||||
uri: U,
|
||||
) -> WsClientBuilder<
|
||||
Base,
|
||||
impl Service<Connect<Uri>, Response = Io, Error = ConnectError>,
|
||||
>
|
||||
fn new<U>(uri: U) -> WsClientBuilder<Base, Connector<Uri>>
|
||||
where
|
||||
Uri: TryFrom<U>,
|
||||
<Uri as TryFrom<U>>::Error: Into<HttpError>,
|
||||
|
@ -673,6 +663,11 @@ impl<F> WsConnection<F> {
|
|||
Self { io, codec, res }
|
||||
}
|
||||
|
||||
/// Get codec reference
|
||||
pub fn codec(&self) -> &ws::Codec {
|
||||
&self.codec
|
||||
}
|
||||
|
||||
/// Get reference to response
|
||||
pub fn response(&self) -> &ClientResponse {
|
||||
&self.res
|
||||
|
|
|
@ -91,7 +91,7 @@ impl<F: Filter> Filter for WsTransport<F> {
|
|||
|
||||
#[inline]
|
||||
fn get_read_buf(&self) -> Option<BytesMut> {
|
||||
self.read_buf.take()
|
||||
self.inner.get_read_buf().or_else(|| self.read_buf.take())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -99,115 +99,127 @@ impl<F: Filter> Filter for WsTransport<F> {
|
|||
None
|
||||
}
|
||||
|
||||
fn release_read_buf(&self, mut src: BytesMut, nbytes: usize) -> Result<(), io::Error> {
|
||||
if nbytes == 0 {
|
||||
if !src.is_empty() {
|
||||
self.read_buf.set(Some(src));
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
let (hw, lw) = self.pool.read_params().unpack();
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
src: BytesMut,
|
||||
dst: &mut Option<BytesMut>,
|
||||
nbytes: usize,
|
||||
) -> io::Result<usize> {
|
||||
let mut src = {
|
||||
let mut dst = None;
|
||||
self.inner.release_read_buf(src, &mut dst, nbytes)?;
|
||||
|
||||
// get inner filter buffer
|
||||
let mut buf = if let Some(buf) = self.inner.get_read_buf() {
|
||||
buf
|
||||
if let Some(dst) = dst {
|
||||
dst
|
||||
} else {
|
||||
self.pool.get_read_buf()
|
||||
};
|
||||
let len = buf.len();
|
||||
let mut flags = self.flags.get();
|
||||
return Ok(0);
|
||||
}
|
||||
};
|
||||
let (hw, lw) = self.pool.read_params().unpack();
|
||||
|
||||
// read from input buffer
|
||||
loop {
|
||||
let result = self.codec.decode(&mut src).map_err(|e| {
|
||||
log::trace!("ws codec failed to decode bytes stream: {:?}", e);
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
})?;
|
||||
// get outter filter buffer
|
||||
if dst.is_none() {
|
||||
*dst = Some(self.pool.get_read_buf());
|
||||
}
|
||||
let buf = dst.as_mut().unwrap();
|
||||
let mut flags = self.flags.get();
|
||||
let mut nbytes = 0;
|
||||
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
// read from input buffer
|
||||
loop {
|
||||
let result = self.codec.decode(&mut src).map_err(|e| {
|
||||
log::trace!("ws codec failed to decode bytes stream: {:?}", e);
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
})?;
|
||||
|
||||
match result {
|
||||
Some(frame) => match frame {
|
||||
Frame::Binary(bin) => buf.extend_from_slice(&bin),
|
||||
Frame::Continuation(item) => match item {
|
||||
Item::FirstText(_) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"WebSocket text continuation frames are not supported",
|
||||
));
|
||||
}
|
||||
Item::FirstBinary(bin) => {
|
||||
flags = self.insert_flags(Flags::CONTINUATION);
|
||||
// make sure we've got room
|
||||
let remaining = buf.remaining_mut();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
match result {
|
||||
Some(frame) => match frame {
|
||||
Frame::Binary(bin) => {
|
||||
nbytes += bin.len();
|
||||
buf.extend_from_slice(&bin)
|
||||
}
|
||||
Frame::Continuation(item) => match item {
|
||||
Item::FirstText(_) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"WebSocket text continuation frames are not supported",
|
||||
));
|
||||
}
|
||||
Item::FirstBinary(bin) => {
|
||||
nbytes += bin.len();
|
||||
buf.extend_from_slice(&bin);
|
||||
flags = self.insert_flags(Flags::CONTINUATION);
|
||||
}
|
||||
Item::Continue(bin) => {
|
||||
if flags.contains(Flags::CONTINUATION) {
|
||||
nbytes += bin.len();
|
||||
buf.extend_from_slice(&bin);
|
||||
}
|
||||
Item::Continue(bin) => {
|
||||
if flags.contains(Flags::CONTINUATION) {
|
||||
buf.extend_from_slice(&bin);
|
||||
} else {
|
||||
return Err(io::Error::new(
|
||||
} else {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Continuation frame must follow data frame with FIN bit clear",
|
||||
));
|
||||
}
|
||||
}
|
||||
Item::Last(bin) => {
|
||||
if flags.contains(Flags::CONTINUATION) {
|
||||
flags = self.remove_flags(Flags::CONTINUATION);
|
||||
buf.extend_from_slice(&bin);
|
||||
} else {
|
||||
return Err(io::Error::new(
|
||||
}
|
||||
Item::Last(bin) => {
|
||||
if flags.contains(Flags::CONTINUATION) {
|
||||
nbytes += bin.len();
|
||||
buf.extend_from_slice(&bin);
|
||||
flags = self.remove_flags(Flags::CONTINUATION);
|
||||
} else {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"Received last frame without initial continuation frame",
|
||||
));
|
||||
}
|
||||
}
|
||||
},
|
||||
Frame::Text(_) => {
|
||||
log::trace!("WebSocket text frames are not supported");
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"WebSocket text frames are not supported",
|
||||
));
|
||||
}
|
||||
Frame::Ping(msg) => {
|
||||
let mut b = self
|
||||
.inner
|
||||
.get_write_buf()
|
||||
.unwrap_or_else(|| self.pool.get_write_buf());
|
||||
self.codec
|
||||
.encode(Message::Pong(msg), &mut b)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
self.release_write_buf(b)?;
|
||||
}
|
||||
Frame::Pong(_) => (),
|
||||
Frame::Close(_) => {
|
||||
let mut b = self
|
||||
.inner
|
||||
.get_write_buf()
|
||||
.unwrap_or_else(|| self.pool.get_write_buf());
|
||||
self.codec
|
||||
.encode(Message::Close(None), &mut b)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
self.release_write_buf(b)?;
|
||||
break;
|
||||
}
|
||||
},
|
||||
None => break,
|
||||
}
|
||||
Frame::Text(_) => {
|
||||
log::trace!("WebSocket text frames are not supported");
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"WebSocket text frames are not supported",
|
||||
));
|
||||
}
|
||||
Frame::Ping(msg) => {
|
||||
let mut b = self
|
||||
.inner
|
||||
.get_write_buf()
|
||||
.unwrap_or_else(|| self.pool.get_write_buf());
|
||||
self.codec
|
||||
.encode(Message::Pong(msg), &mut b)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
self.release_write_buf(b)?;
|
||||
}
|
||||
Frame::Pong(_) => (),
|
||||
Frame::Close(_) => {
|
||||
let mut b = self
|
||||
.inner
|
||||
.get_write_buf()
|
||||
.unwrap_or_else(|| self.pool.get_write_buf());
|
||||
self.codec
|
||||
.encode(Message::Close(None), &mut b)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
self.release_write_buf(b)?;
|
||||
break;
|
||||
}
|
||||
},
|
||||
None => break,
|
||||
}
|
||||
|
||||
if !src.is_empty() {
|
||||
self.read_buf.set(Some(src));
|
||||
} else {
|
||||
self.pool.release_read_buf(src);
|
||||
}
|
||||
let new_bytes = buf.len() - len;
|
||||
self.inner.release_read_buf(buf, new_bytes)
|
||||
}
|
||||
|
||||
if !src.is_empty() {
|
||||
self.read_buf.set(Some(src));
|
||||
} else {
|
||||
self.pool.release_read_buf(src);
|
||||
}
|
||||
Ok(nbytes)
|
||||
}
|
||||
|
||||
fn release_write_buf(&self, src: BytesMut) -> Result<(), io::Error> {
|
||||
|
@ -253,3 +265,39 @@ impl<F: Filter> FilterFactory<F> for WsTransportFactory {
|
|||
Ready::from(st.map_filter(|inner| Ok(WsTransport::new(inner, self.codec, pool))))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{codec::BytesCodec, io::Io, testing::IoTest, util::Bytes};
|
||||
|
||||
#[crate::rt_test]
|
||||
async fn basics() {
|
||||
let (client, server) = IoTest::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
server.remote_buffer_cap(1024);
|
||||
|
||||
let client = WsTransportFactory::new(Codec::new().client_mode())
|
||||
.create(Io::new(client))
|
||||
.await
|
||||
.unwrap();
|
||||
let server = WsTransportFactory::new(Codec::new())
|
||||
.create(Io::new(server))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
client
|
||||
.send(&BytesCodec, Bytes::from_static(b"DATA"))
|
||||
.await
|
||||
.unwrap();
|
||||
let res = server.recv(&BytesCodec).await.unwrap().unwrap();
|
||||
assert_eq!(res, b"DATA".as_ref());
|
||||
|
||||
server
|
||||
.send(&BytesCodec, Bytes::from_static(b"DATA"))
|
||||
.await
|
||||
.unwrap();
|
||||
let res = client.recv(&BytesCodec).await.unwrap().unwrap();
|
||||
assert_eq!(res, b"DATA".as_ref());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,13 +5,16 @@ use futures::future::{err, ok, ready};
|
|||
use futures::stream::{once, Stream, StreamExt};
|
||||
use tls_openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod};
|
||||
|
||||
use ntex::codec::BytesCodec;
|
||||
use ntex::http::error::PayloadError;
|
||||
use ntex::http::header::{self, HeaderName, HeaderValue};
|
||||
use ntex::http::test::server as test_server;
|
||||
use ntex::http::{body, HttpService, Method, Request, Response, StatusCode, Version};
|
||||
use ntex::http::ws::handshake_response;
|
||||
use ntex::http::{body, h1, HttpService, Method, Request, Response, StatusCode, Version};
|
||||
use ntex::io::Io;
|
||||
use ntex::service::{fn_service, ServiceFactory};
|
||||
use ntex::util::{Bytes, BytesMut};
|
||||
use ntex::{time::Seconds, web::error::InternalError};
|
||||
use ntex::util::{Bytes, BytesMut, Ready};
|
||||
use ntex::{time::Seconds, web::error::InternalError, ws};
|
||||
|
||||
async fn load_body<S>(stream: S) -> Result<BytesMut, PayloadError>
|
||||
where
|
||||
|
@ -435,3 +438,49 @@ async fn test_ssl_handshake_timeout() {
|
|||
let _ = stream.read_to_string(&mut data);
|
||||
assert!(data.is_empty());
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_ws_transport() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.upgrade(|(req, io, codec): (Request, Io<_>, h1::Codec)| {
|
||||
async move {
|
||||
let res = handshake_response(req.head()).finish();
|
||||
|
||||
// send handshake respone
|
||||
io.encode(
|
||||
h1::Message::Item((res.drop_body(), body::BodySize::None)),
|
||||
&codec,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let io = io
|
||||
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
|
||||
.await?;
|
||||
|
||||
// start websocket service
|
||||
loop {
|
||||
if let Some(item) =
|
||||
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
|
||||
{
|
||||
io.send(&BytesCodec, item.freeze()).await.unwrap()
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
}
|
||||
})
|
||||
.finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound()))
|
||||
.openssl(ssl_acceptor())
|
||||
});
|
||||
|
||||
let io = srv.wss().await.unwrap().into_inner().0;
|
||||
let codec = ws::Codec::default().client_mode();
|
||||
|
||||
io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text")))
|
||||
.await
|
||||
.unwrap();
|
||||
let item = io.recv(&codec).await.unwrap().unwrap();
|
||||
assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text")));
|
||||
}
|
||||
|
|
|
@ -2,6 +2,9 @@ use std::sync::{Arc, Mutex};
|
|||
use std::task::{Context, Poll};
|
||||
use std::{cell::Cell, future::Future, io, pin::Pin};
|
||||
|
||||
use ntex::codec::BytesCodec;
|
||||
use ntex::http::test::server as test_server;
|
||||
use ntex::http::ws::handshake_response;
|
||||
use ntex::http::{
|
||||
body, h1, test, ws::handshake, HttpService, Request, Response, StatusCode,
|
||||
};
|
||||
|
@ -239,3 +242,49 @@ async fn test_simple() {
|
|||
|
||||
assert!(ws_service.was_polled());
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_transport() {
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.upgrade(|(req, io, codec): (Request, Io, h1::Codec)| {
|
||||
async move {
|
||||
let res = handshake_response(req.head()).finish();
|
||||
|
||||
// send handshake respone
|
||||
io.encode(
|
||||
h1::Message::Item((res.drop_body(), body::BodySize::None)),
|
||||
&codec,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let io = io
|
||||
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
|
||||
.await?;
|
||||
|
||||
// start websocket service
|
||||
loop {
|
||||
if let Some(item) =
|
||||
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
|
||||
{
|
||||
io.send(&BytesCodec, item.freeze()).await.unwrap()
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
}
|
||||
})
|
||||
.finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound()))
|
||||
});
|
||||
|
||||
// client service
|
||||
let io = srv.ws().await.unwrap().into_inner().0;
|
||||
|
||||
let codec = ws::Codec::default().client_mode();
|
||||
io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text")))
|
||||
.await
|
||||
.unwrap();
|
||||
let item = io.recv(&codec).await.unwrap().unwrap();
|
||||
assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text")));
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use std::io;
|
||||
|
||||
use ntex::codec::BytesCodec;
|
||||
use ntex::http::test::server as test_server;
|
||||
use ntex::http::ws::handshake_response;
|
||||
use ntex::http::{body::BodySize, h1, HttpService, Request, Response};
|
||||
|
@ -79,3 +80,39 @@ async fn test_simple() {
|
|||
let item = io.recv(&codec).await.unwrap().unwrap();
|
||||
assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into())));
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_transport() {
|
||||
env_logger::init();
|
||||
let mut srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.upgrade(|(req, io, codec): (Request, Io, h1::Codec)| {
|
||||
async move {
|
||||
let res = handshake_response(req.head()).finish();
|
||||
|
||||
// send handshake respone
|
||||
io.encode(h1::Message::Item((res.drop_body(), BodySize::None)), &codec)
|
||||
.unwrap();
|
||||
|
||||
// start websocket service
|
||||
Dispatcher::new(
|
||||
io.seal(),
|
||||
ws::Codec::default(),
|
||||
ws_service,
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
})
|
||||
.finish(|_| Ready::Ok::<_, io::Error>(Response::NotFound()))
|
||||
});
|
||||
|
||||
// client service
|
||||
let io = srv.ws().await.unwrap().into_transport().await;
|
||||
|
||||
io.send(&BytesCodec, Bytes::from_static(b"text"))
|
||||
.await
|
||||
.unwrap();
|
||||
let item = io.recv(&BytesCodec).await.unwrap().unwrap();
|
||||
assert_eq!(item, Bytes::from_static(b"text"));
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use std::{sync::mpsc, thread, time::Duration};
|
|||
use tls_openssl::ssl::SslAcceptorBuilder;
|
||||
|
||||
use ntex::web::{self, App, HttpResponse, HttpServer};
|
||||
use ntex::{rt, server::TestServer, time::Seconds};
|
||||
use ntex::{rt, server::TestServer, time::sleep, time::Seconds};
|
||||
|
||||
#[cfg(unix)]
|
||||
#[ntex::test]
|
||||
|
@ -203,7 +203,7 @@ async fn test_rustls() {
|
|||
// stop
|
||||
let _ = srv.stop(false);
|
||||
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
sys.stop();
|
||||
}
|
||||
|
||||
|
@ -253,7 +253,7 @@ async fn test_bind_uds() {
|
|||
// stop
|
||||
let _ = srv.stop(false);
|
||||
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
sys.stop();
|
||||
}
|
||||
|
||||
|
@ -305,6 +305,6 @@ async fn test_listen_uds() {
|
|||
// stop
|
||||
let _ = srv.stop(false);
|
||||
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
sleep(Duration::from_millis(100)).await;
|
||||
sys.stop();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue