Release write buf after error

This commit is contained in:
Nikolay Kim 2021-12-26 19:24:13 +06:00
parent c801a9ea57
commit 19c6a2b731
15 changed files with 78 additions and 95 deletions

View file

@ -85,7 +85,7 @@ where
&self, &self,
item: <U as Encoder>::Item, item: <U as Encoder>::Item,
) -> Result<(), Either<U::Error, io::Error>> { ) -> Result<(), Either<U::Error, io::Error>> {
self.io.send(&self.codec, item).await self.io.send(item, &self.codec).await
} }
} }

View file

@ -459,22 +459,18 @@ impl<F> Io<F> {
/// Encode item, send to a peer /// Encode item, send to a peer
pub async fn send<U>( pub async fn send<U>(
&self, &self,
codec: &U,
item: U::Item, item: U::Item,
codec: &U,
) -> Result<(), Either<U::Error, io::Error>> ) -> Result<(), Either<U::Error, io::Error>>
where where
U: Encoder, U: Encoder,
{ {
let filter = self.filter(); self.encode(item, codec).map_err(Either::Left)?;
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf());
codec.encode(item, &mut buf).map_err(Either::Left)?;
filter.release_write_buf(buf).map_err(Either::Right)?;
poll_fn(|cx| self.poll_flush(cx, true)) poll_fn(|cx| self.poll_flush(cx, true))
.await .await
.map_err(Either::Right)?; .map_err(Either::Right)?;
Ok(()) Ok(())
} }

View file

@ -145,8 +145,12 @@ impl IoRef {
let mut buf = filter let mut buf = filter
.get_write_buf() .get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf()); .unwrap_or_else(|| self.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
let result = f(&mut buf); let result = f(&mut buf);
if is_write_sleep {
self.0.write_task.wake();
}
filter.release_write_buf(buf)?; filter.release_write_buf(buf)?;
Ok(result) Ok(result)
} }
@ -177,29 +181,28 @@ impl IoRef {
let flags = self.0.flags.get(); let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0.filter.get(); self.with_write_buf(|buf| {
let mut buf = filter let (hw, lw) = self.memory_pool().write_params().unpack();
.get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
let (hw, lw) = self.memory_pool().write_params().unpack();
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
if remaining < lw { if remaining < lw {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
} }
// encode item and wake write task // encode item and wake write task
codec.encode(item, &mut buf)?; codec.encode(item, buf)
if is_write_sleep { })
self.0.write_task.wake(); .map_or_else(
} |err| {
if let Err(err) = filter.release_write_buf(buf) { self.0.set_error(Some(err));
self.0.set_error(Some(err)); Ok(())
} },
|item| item,
)
} else {
Ok(())
} }
Ok(())
} }
#[inline] #[inline]
@ -221,31 +224,15 @@ impl IoRef {
#[inline] #[inline]
/// Write bytes to a buffer and wake up write task /// Write bytes to a buffer and wake up write task
/// pub fn write(&self, src: &[u8]) -> io::Result<()> {
/// Returns write buffer state, false is returned if write buffer if full.
pub fn write(&self, src: &[u8]) -> Result<bool, io::Error> {
let flags = self.0.flags.get(); let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0.filter.get(); self.with_write_buf(|buf| {
let mut buf = filter buf.extend_from_slice(src);
.get_write_buf() })
.unwrap_or_else(|| self.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
// write and wake write task
buf.extend_from_slice(src);
let result = buf.len() < self.memory_pool().write_params_high();
if is_write_sleep {
self.0.write_task.wake();
}
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
}
Ok(result)
} else { } else {
Ok(true) Ok(())
} }
} }
} }
@ -318,14 +305,14 @@ mod tests {
client.remote_buffer_cap(1024); client.remote_buffer_cap(1024);
let state = Io::new(server); let state = Io::new(server);
state state
.send(&BytesCodec, Bytes::from_static(b"test")) .send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
let buf = client.read().await.unwrap(); let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test")); assert_eq!(buf, Bytes::from_static(b"test"));
client.write_error(io::Error::new(io::ErrorKind::Other, "err")); client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await; let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
assert!(res.is_err()); assert!(res.is_err());
assert!(state.flags().contains(Flags::IO_ERR)); assert!(state.flags().contains(Flags::IO_ERR));
@ -496,7 +483,7 @@ mod tests {
assert_eq!(msg, Bytes::from_static(BIN)); assert_eq!(msg, Bytes::from_static(BIN));
state state
.send(&BytesCodec, Bytes::from_static(b"test")) .send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
let buf = client.read().await.unwrap(); let buf = client.read().await.unwrap();
@ -541,7 +528,7 @@ mod tests {
assert_eq!(msg, Bytes::from_static(BIN)); assert_eq!(msg, Bytes::from_static(BIN));
state state
.send(&BytesCodec, Bytes::from_static(b"test")) .send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
let buf = client.read().await.unwrap(); let buf = client.read().await.unwrap();

View file

@ -20,7 +20,7 @@ async fn main() -> io::Result<()> {
let io = connector.connect("127.0.0.1:8443").await.unwrap(); let io = connector.connect("127.0.0.1:8443").await.unwrap();
println!("Connected to ssl server"); println!("Connected to ssl server");
let result = io let result = io
.send(&codec::BytesCodec, Bytes::from_static(b"hello")) .send(Bytes::from_static(b"hello"), &codec::BytesCodec)
.await .await
.map_err(Either::into_inner)?; .map_err(Either::into_inner)?;

View file

@ -40,8 +40,8 @@ async fn main() -> io::Result<()> {
println!("New client is connected"); println!("New client is connected");
io.send( io.send(
&codec::BytesCodec,
ntex_bytes::Bytes::from_static(b"Welcome!\n"), ntex_bytes::Bytes::from_static(b"Welcome!\n"),
&codec::BytesCodec,
) )
.await .await
.map_err(Either::into_inner)?; .map_err(Either::into_inner)?;
@ -50,7 +50,7 @@ async fn main() -> io::Result<()> {
match io.recv(&codec::BytesCodec).await { match io.recv(&codec::BytesCodec).await {
Ok(Some(msg)) => { Ok(Some(msg)) => {
println!("Got message: {:?}", msg); println!("Got message: {:?}", msg);
io.send(&codec::BytesCodec, msg.freeze()) io.send(msg.freeze(), &codec::BytesCodec)
.await .await
.map_err(Either::into_inner)?; .map_err(Either::into_inner)?;
} }

View file

@ -32,7 +32,7 @@ async fn main() -> io::Result<()> {
match io.recv(&codec::BytesCodec).await { match io.recv(&codec::BytesCodec).await {
Ok(Some(msg)) => { Ok(Some(msg)) => {
println!("Got message: {:?}", msg); println!("Got message: {:?}", msg);
io.send(&codec::BytesCodec, msg.freeze()) io.send(msg.freeze(), &codec::BytesCodec)
.await .await
.map_err(Either::into_inner)?; .map_err(Either::into_inner)?;
} }

View file

@ -59,7 +59,7 @@ where
// send request // send request
let codec = h1::ClientCodec::default(); let codec = h1::ClientCodec::default();
io.send(&codec, (head, body.size()).into()).await?; io.send((head, body.size()).into(), &codec).await?;
log::trace!("http1 request has been sent"); log::trace!("http1 request has been sent");

View file

@ -164,8 +164,8 @@ where
// send request and read response // send request and read response
let fut = async { let fut = async {
io.send( io.send(
&codec,
(RequestHeadType::Rc(head, Some(headers)), BodySize::None).into(), (RequestHeadType::Rc(head, Some(headers)), BodySize::None).into(),
&codec,
) )
.await?; .await?;
io.recv(&codec) io.recv(&codec)

View file

@ -287,14 +287,14 @@ mod tests {
.unwrap(); .unwrap();
client client
.send(&BytesCodec, Bytes::from_static(b"DATA")) .send(Bytes::from_static(b"DATA"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
let res = server.recv(&BytesCodec).await.unwrap().unwrap(); let res = server.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(res, b"DATA".as_ref()); assert_eq!(res, b"DATA".as_ref());
server server
.send(&BytesCodec, Bytes::from_static(b"DATA")) .send(Bytes::from_static(b"DATA"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
let res = client.recv(&BytesCodec).await.unwrap().unwrap(); let res = client.recv(&BytesCodec).await.unwrap().unwrap();

View file

@ -12,7 +12,7 @@ use ntex::util::Bytes;
async fn test_string() { async fn test_string() {
let srv = test_server(|| { let srv = test_server(|| {
fn_service(|io: Io| async move { fn_service(|io: Io| async move {
io.send(&BytesCodec, Bytes::from_static(b"test")) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
Ok::<_, io::Error>(()) Ok::<_, io::Error>(())
@ -30,7 +30,7 @@ async fn test_string() {
async fn test_rustls_string() { async fn test_rustls_string() {
let srv = test_server(|| { let srv = test_server(|| {
fn_service(|io: Io| async move { fn_service(|io: Io| async move {
io.send(&BytesCodec, Bytes::from_static(b"test")) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
Ok::<_, io::Error>(()) Ok::<_, io::Error>(())
@ -47,7 +47,7 @@ async fn test_rustls_string() {
async fn test_static_str() { async fn test_static_str() {
let srv = test_server(|| { let srv = test_server(|| {
fn_service(|io: Io| async move { fn_service(|io: Io| async move {
io.send(&BytesCodec, Bytes::from_static(b"test")) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
Ok::<_, io::Error>(()) Ok::<_, io::Error>(())
@ -69,7 +69,7 @@ async fn test_static_str() {
async fn test_new_service() { async fn test_new_service() {
let srv = test_server(|| { let srv = test_server(|| {
fn_service(|io: Io| async move { fn_service(|io: Io| async move {
io.send(&BytesCodec, Bytes::from_static(b"test")) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
Ok::<_, io::Error>(()) Ok::<_, io::Error>(())
@ -89,7 +89,7 @@ async fn test_uri() {
let srv = test_server(|| { let srv = test_server(|| {
fn_service(|io: Io| async move { fn_service(|io: Io| async move {
io.send(&BytesCodec, Bytes::from_static(b"test")) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
Ok::<_, io::Error>(()) Ok::<_, io::Error>(())
@ -111,7 +111,7 @@ async fn test_rustls_uri() {
let srv = test_server(|| { let srv = test_server(|| {
fn_service(|io: Io| async move { fn_service(|io: Io| async move {
io.send(&BytesCodec, Bytes::from_static(b"test")) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
Ok::<_, io::Error>(()) Ok::<_, io::Error>(())

View file

@ -463,7 +463,7 @@ async fn test_ws_transport() {
if let Some(item) = if let Some(item) =
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
{ {
io.send(&BytesCodec, item.freeze()).await.unwrap() io.send(item.freeze(), &BytesCodec).await.unwrap()
} else { } else {
break; break;
} }
@ -478,7 +478,7 @@ async fn test_ws_transport() {
let io = srv.wss().await.unwrap().into_inner().0; let io = srv.wss().await.unwrap().into_inner().0;
let codec = ws::Codec::default().client_mode(); let codec = ws::Codec::default().client_mode();
io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text"))) io.send(ws::Message::Binary(Bytes::from_static(b"text")), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();

View file

@ -95,7 +95,7 @@ async fn test_simple() {
assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS); assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS);
let (io, codec, _) = conn.into_inner(); let (io, codec, _) = conn.into_inner();
io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) io.send(ws::Message::Text(ByteString::from_static("text")), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await; let item = io.recv(&codec).await;
@ -104,7 +104,7 @@ async fn test_simple() {
ws::Frame::Text(Bytes::from_static(b"text")) ws::Frame::Text(Bytes::from_static(b"text"))
); );
io.send(&codec, ws::Message::Binary("text".into())) io.send(ws::Message::Binary("text".into()), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await; let item = io.recv(&codec).await;
@ -113,7 +113,7 @@ async fn test_simple() {
ws::Frame::Binary(Bytes::from_static(&b"text"[..])) ws::Frame::Binary(Bytes::from_static(&b"text"[..]))
); );
io.send(&codec, ws::Message::Ping("text".into())) io.send(ws::Message::Ping("text".into()), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await; let item = io.recv(&codec).await;
@ -123,8 +123,8 @@ async fn test_simple() {
); );
io.send( io.send(
&codec,
ws::Message::Continuation(ws::Item::FirstText("text".into())), ws::Message::Continuation(ws::Item::FirstText("text".into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -136,22 +136,22 @@ async fn test_simple() {
assert!(io assert!(io
.send( .send(
&codec,
ws::Message::Continuation(ws::Item::FirstText("text".into())), ws::Message::Continuation(ws::Item::FirstText("text".into())),
&codec,
) )
.await .await
.is_err()); .is_err());
assert!(io assert!(io
.send( .send(
&codec,
ws::Message::Continuation(ws::Item::FirstBinary("text".into())), ws::Message::Continuation(ws::Item::FirstBinary("text".into())),
&codec,
) )
.await .await
.is_err()); .is_err());
io.send( io.send(
&codec,
ws::Message::Continuation(ws::Item::Continue("text".into())), ws::Message::Continuation(ws::Item::Continue("text".into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -162,8 +162,8 @@ async fn test_simple() {
); );
io.send( io.send(
&codec,
ws::Message::Continuation(ws::Item::Last("text".into())), ws::Message::Continuation(ws::Item::Last("text".into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -175,23 +175,23 @@ async fn test_simple() {
assert!(io assert!(io
.send( .send(
&codec,
ws::Message::Continuation(ws::Item::Continue("text".into())), ws::Message::Continuation(ws::Item::Continue("text".into())),
&codec,
) )
.await .await
.is_err()); .is_err());
assert!(io assert!(io
.send( .send(
&codec,
ws::Message::Continuation(ws::Item::Last("text".into())), ws::Message::Continuation(ws::Item::Last("text".into())),
&codec,
) )
.await .await
.is_err()); .is_err());
io.send( io.send(
&codec,
ws::Message::Continuation(ws::Item::FirstBinary(Bytes::from_static(b"bin"))), ws::Message::Continuation(ws::Item::FirstBinary(Bytes::from_static(b"bin"))),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -202,8 +202,8 @@ async fn test_simple() {
); );
io.send( io.send(
&codec,
ws::Message::Continuation(ws::Item::Continue("text".into())), ws::Message::Continuation(ws::Item::Continue("text".into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -214,8 +214,8 @@ async fn test_simple() {
); );
io.send( io.send(
&codec,
ws::Message::Continuation(ws::Item::Last("text".into())), ws::Message::Continuation(ws::Item::Last("text".into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -226,8 +226,8 @@ async fn test_simple() {
); );
io.send( io.send(
&codec,
ws::Message::Close(Some(ws::CloseCode::Normal.into())), ws::Message::Close(Some(ws::CloseCode::Normal.into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -265,7 +265,7 @@ async fn test_transport() {
if let Some(item) = if let Some(item) =
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
{ {
io.send(&BytesCodec, item.freeze()).await.unwrap() io.send(item.freeze(), &BytesCodec).await.unwrap()
} else { } else {
break; break;
} }
@ -280,7 +280,7 @@ async fn test_transport() {
let io = srv.ws().await.unwrap().into_inner().0; let io = srv.ws().await.unwrap().into_inner().0;
let codec = ws::Codec::default().client_mode(); let codec = ws::Codec::default().client_mode();
io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text"))) io.send(ws::Message::Binary(Bytes::from_static(b"text")), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();

View file

@ -52,27 +52,27 @@ async fn test_simple() {
// client service // client service
let (io, codec, _) = srv.ws().await.unwrap().into_inner(); let (io, codec, _) = srv.ws().await.unwrap().into_inner();
io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) io.send(ws::Message::Text(ByteString::from_static("text")), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text")));
io.send(&codec, ws::Message::Binary("text".into())) io.send(ws::Message::Binary("text".into()), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text")));
io.send(&codec, ws::Message::Ping("text".into())) io.send(ws::Message::Ping("text".into()), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Pong("text".to_string().into())); assert_eq!(item, ws::Frame::Pong("text".to_string().into()));
io.send( io.send(
&codec,
ws::Message::Close(Some(ws::CloseCode::Normal.into())), ws::Message::Close(Some(ws::CloseCode::Normal.into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();
@ -110,7 +110,7 @@ async fn test_transport() {
// client service // client service
let io = srv.ws().await.unwrap().into_transport().await; let io = srv.ws().await.unwrap().into_transport().await;
io.send(&BytesCodec, Bytes::from_static(b"text")) io.send(Bytes::from_static(b"text"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&BytesCodec).await.unwrap().unwrap(); let item = io.recv(&BytesCodec).await.unwrap().unwrap();

View file

@ -77,7 +77,7 @@ fn test_run() {
.disable_signals() .disable_signals()
.bind("test", addr, move |_| { .bind("test", addr, move |_| {
fn_service(|io: Io| async move { fn_service(|io: Io| async move {
io.send(&BytesCodec, Bytes::from_static(b"test")) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
Ok::<_, ()>(()) Ok::<_, ()>(())

View file

@ -38,27 +38,27 @@ async fn web_ws() {
// client service // client service
let (io, codec, _) = srv.ws().await.unwrap().into_inner(); let (io, codec, _) = srv.ws().await.unwrap().into_inner();
io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) io.send(ws::Message::Text(ByteString::from_static("text")), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text")));
io.send(&codec, ws::Message::Binary("text".into())) io.send(ws::Message::Binary("text".into()), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text")));
io.send(&codec, ws::Message::Ping("text".into())) io.send(ws::Message::Ping("text".into()), &codec)
.await .await
.unwrap(); .unwrap();
let item = io.recv(&codec).await.unwrap().unwrap(); let item = io.recv(&codec).await.unwrap().unwrap();
assert_eq!(item, ws::Frame::Pong("text".to_string().into())); assert_eq!(item, ws::Frame::Pong("text".to_string().into()));
io.send( io.send(
&codec,
ws::Message::Close(Some(ws::CloseCode::Normal.into())), ws::Message::Close(Some(ws::CloseCode::Normal.into())),
&codec,
) )
.await .await
.unwrap(); .unwrap();