diff --git a/ntex-io/src/framed.rs b/ntex-io/src/framed.rs index 28e4db6b..ba7f3911 100644 --- a/ntex-io/src/framed.rs +++ b/ntex-io/src/framed.rs @@ -85,7 +85,7 @@ where &self, item: ::Item, ) -> Result<(), Either> { - self.io.send(&self.codec, item).await + self.io.send(item, &self.codec).await } } diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index f3e35e34..32247284 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -459,22 +459,18 @@ impl Io { /// Encode item, send to a peer pub async fn send( &self, - codec: &U, item: U::Item, + codec: &U, ) -> Result<(), Either> where U: Encoder, { - let filter = self.filter(); - let mut buf = filter - .get_write_buf() - .unwrap_or_else(|| self.memory_pool().get_write_buf()); - codec.encode(item, &mut buf).map_err(Either::Left)?; - filter.release_write_buf(buf).map_err(Either::Right)?; + self.encode(item, codec).map_err(Either::Left)?; poll_fn(|cx| self.poll_flush(cx, true)) .await .map_err(Either::Right)?; + Ok(()) } diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 7fbeaccb..bbeba13d 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -145,8 +145,12 @@ impl IoRef { let mut buf = filter .get_write_buf() .unwrap_or_else(|| self.memory_pool().get_write_buf()); + let is_write_sleep = buf.is_empty(); let result = f(&mut buf); + if is_write_sleep { + self.0.write_task.wake(); + } filter.release_write_buf(buf)?; Ok(result) } @@ -177,29 +181,28 @@ impl IoRef { let flags = self.0.flags.get(); if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { - let filter = self.0.filter.get(); - let mut buf = filter - .get_write_buf() - .unwrap_or_else(|| self.memory_pool().get_write_buf()); - let is_write_sleep = buf.is_empty(); - let (hw, lw) = self.memory_pool().write_params().unpack(); + self.with_write_buf(|buf| { + let (hw, lw) = self.memory_pool().write_params().unpack(); - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } - // encode item and wake write task - codec.encode(item, &mut buf)?; - if is_write_sleep { - self.0.write_task.wake(); - } - if let Err(err) = filter.release_write_buf(buf) { - self.0.set_error(Some(err)); - } + // encode item and wake write task + codec.encode(item, buf) + }) + .map_or_else( + |err| { + self.0.set_error(Some(err)); + Ok(()) + }, + |item| item, + ) + } else { + Ok(()) } - Ok(()) } #[inline] @@ -221,31 +224,15 @@ impl IoRef { #[inline] /// Write bytes to a buffer and wake up write task - /// - /// Returns write buffer state, false is returned if write buffer if full. - pub fn write(&self, src: &[u8]) -> Result { + pub fn write(&self, src: &[u8]) -> io::Result<()> { let flags = self.0.flags.get(); if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { - let filter = self.0.filter.get(); - let mut buf = filter - .get_write_buf() - .unwrap_or_else(|| self.memory_pool().get_write_buf()); - let is_write_sleep = buf.is_empty(); - - // write and wake write task - buf.extend_from_slice(src); - let result = buf.len() < self.memory_pool().write_params_high(); - if is_write_sleep { - self.0.write_task.wake(); - } - - if let Err(err) = filter.release_write_buf(buf) { - self.0.set_error(Some(err)); - } - Ok(result) + self.with_write_buf(|buf| { + buf.extend_from_slice(src); + }) } else { - Ok(true) + Ok(()) } } } @@ -318,14 +305,14 @@ mod tests { client.remote_buffer_cap(1024); let state = Io::new(server); state - .send(&BytesCodec, Bytes::from_static(b"test")) + .send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); let buf = client.read().await.unwrap(); assert_eq!(buf, Bytes::from_static(b"test")); client.write_error(io::Error::new(io::ErrorKind::Other, "err")); - let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await; + let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await; assert!(res.is_err()); assert!(state.flags().contains(Flags::IO_ERR)); @@ -496,7 +483,7 @@ mod tests { assert_eq!(msg, Bytes::from_static(BIN)); state - .send(&BytesCodec, Bytes::from_static(b"test")) + .send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); let buf = client.read().await.unwrap(); @@ -541,7 +528,7 @@ mod tests { assert_eq!(msg, Bytes::from_static(BIN)); state - .send(&BytesCodec, Bytes::from_static(b"test")) + .send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); let buf = client.read().await.unwrap(); diff --git a/ntex-tls/examples/client.rs b/ntex-tls/examples/client.rs index 54347fd6..f0f92406 100644 --- a/ntex-tls/examples/client.rs +++ b/ntex-tls/examples/client.rs @@ -20,7 +20,7 @@ async fn main() -> io::Result<()> { let io = connector.connect("127.0.0.1:8443").await.unwrap(); println!("Connected to ssl server"); let result = io - .send(&codec::BytesCodec, Bytes::from_static(b"hello")) + .send(Bytes::from_static(b"hello"), &codec::BytesCodec) .await .map_err(Either::into_inner)?; diff --git a/ntex-tls/examples/rustls-server.rs b/ntex-tls/examples/rustls-server.rs index 2c81d351..790fd633 100644 --- a/ntex-tls/examples/rustls-server.rs +++ b/ntex-tls/examples/rustls-server.rs @@ -40,8 +40,8 @@ async fn main() -> io::Result<()> { println!("New client is connected"); io.send( - &codec::BytesCodec, ntex_bytes::Bytes::from_static(b"Welcome!\n"), + &codec::BytesCodec, ) .await .map_err(Either::into_inner)?; @@ -50,7 +50,7 @@ async fn main() -> io::Result<()> { match io.recv(&codec::BytesCodec).await { Ok(Some(msg)) => { println!("Got message: {:?}", msg); - io.send(&codec::BytesCodec, msg.freeze()) + io.send(msg.freeze(), &codec::BytesCodec) .await .map_err(Either::into_inner)?; } diff --git a/ntex-tls/examples/server.rs b/ntex-tls/examples/server.rs index 58bffb14..76e54640 100644 --- a/ntex-tls/examples/server.rs +++ b/ntex-tls/examples/server.rs @@ -32,7 +32,7 @@ async fn main() -> io::Result<()> { match io.recv(&codec::BytesCodec).await { Ok(Some(msg)) => { println!("Got message: {:?}", msg); - io.send(&codec::BytesCodec, msg.freeze()) + io.send(msg.freeze(), &codec::BytesCodec) .await .map_err(Either::into_inner)?; } diff --git a/ntex/src/http/client/h1proto.rs b/ntex/src/http/client/h1proto.rs index 61570fca..b4458420 100644 --- a/ntex/src/http/client/h1proto.rs +++ b/ntex/src/http/client/h1proto.rs @@ -59,7 +59,7 @@ where // send request let codec = h1::ClientCodec::default(); - io.send(&codec, (head, body.size()).into()).await?; + io.send((head, body.size()).into(), &codec).await?; log::trace!("http1 request has been sent"); diff --git a/ntex/src/ws/client.rs b/ntex/src/ws/client.rs index e36c84c2..bbfec674 100644 --- a/ntex/src/ws/client.rs +++ b/ntex/src/ws/client.rs @@ -164,8 +164,8 @@ where // send request and read response let fut = async { io.send( - &codec, (RequestHeadType::Rc(head, Some(headers)), BodySize::None).into(), + &codec, ) .await?; io.recv(&codec) diff --git a/ntex/src/ws/transport.rs b/ntex/src/ws/transport.rs index 5dd98c5c..f1ba9683 100644 --- a/ntex/src/ws/transport.rs +++ b/ntex/src/ws/transport.rs @@ -287,14 +287,14 @@ mod tests { .unwrap(); client - .send(&BytesCodec, Bytes::from_static(b"DATA")) + .send(Bytes::from_static(b"DATA"), &BytesCodec) .await .unwrap(); let res = server.recv(&BytesCodec).await.unwrap().unwrap(); assert_eq!(res, b"DATA".as_ref()); server - .send(&BytesCodec, Bytes::from_static(b"DATA")) + .send(Bytes::from_static(b"DATA"), &BytesCodec) .await .unwrap(); let res = client.recv(&BytesCodec).await.unwrap().unwrap(); diff --git a/ntex/tests/connect.rs b/ntex/tests/connect.rs index 94bf5451..9721c769 100644 --- a/ntex/tests/connect.rs +++ b/ntex/tests/connect.rs @@ -12,7 +12,7 @@ use ntex::util::Bytes; async fn test_string() { let srv = test_server(|| { fn_service(|io: Io| async move { - io.send(&BytesCodec, Bytes::from_static(b"test")) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); Ok::<_, io::Error>(()) @@ -30,7 +30,7 @@ async fn test_string() { async fn test_rustls_string() { let srv = test_server(|| { fn_service(|io: Io| async move { - io.send(&BytesCodec, Bytes::from_static(b"test")) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); Ok::<_, io::Error>(()) @@ -47,7 +47,7 @@ async fn test_rustls_string() { async fn test_static_str() { let srv = test_server(|| { fn_service(|io: Io| async move { - io.send(&BytesCodec, Bytes::from_static(b"test")) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); Ok::<_, io::Error>(()) @@ -69,7 +69,7 @@ async fn test_static_str() { async fn test_new_service() { let srv = test_server(|| { fn_service(|io: Io| async move { - io.send(&BytesCodec, Bytes::from_static(b"test")) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); Ok::<_, io::Error>(()) @@ -89,7 +89,7 @@ async fn test_uri() { let srv = test_server(|| { fn_service(|io: Io| async move { - io.send(&BytesCodec, Bytes::from_static(b"test")) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); Ok::<_, io::Error>(()) @@ -111,7 +111,7 @@ async fn test_rustls_uri() { let srv = test_server(|| { fn_service(|io: Io| async move { - io.send(&BytesCodec, Bytes::from_static(b"test")) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); Ok::<_, io::Error>(()) diff --git a/ntex/tests/http_openssl.rs b/ntex/tests/http_openssl.rs index 1985aa7d..3295f4f0 100644 --- a/ntex/tests/http_openssl.rs +++ b/ntex/tests/http_openssl.rs @@ -463,7 +463,7 @@ async fn test_ws_transport() { if let Some(item) = io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? { - io.send(&BytesCodec, item.freeze()).await.unwrap() + io.send(item.freeze(), &BytesCodec).await.unwrap() } else { break; } @@ -478,7 +478,7 @@ async fn test_ws_transport() { let io = srv.wss().await.unwrap().into_inner().0; let codec = ws::Codec::default().client_mode(); - io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text"))) + io.send(ws::Message::Binary(Bytes::from_static(b"text")), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); diff --git a/ntex/tests/http_ws.rs b/ntex/tests/http_ws.rs index 2506fd51..55845ae3 100644 --- a/ntex/tests/http_ws.rs +++ b/ntex/tests/http_ws.rs @@ -95,7 +95,7 @@ async fn test_simple() { assert_eq!(conn.response().status(), StatusCode::SWITCHING_PROTOCOLS); let (io, codec, _) = conn.into_inner(); - io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) + io.send(ws::Message::Text(ByteString::from_static("text")), &codec) .await .unwrap(); let item = io.recv(&codec).await; @@ -104,7 +104,7 @@ async fn test_simple() { ws::Frame::Text(Bytes::from_static(b"text")) ); - io.send(&codec, ws::Message::Binary("text".into())) + io.send(ws::Message::Binary("text".into()), &codec) .await .unwrap(); let item = io.recv(&codec).await; @@ -113,7 +113,7 @@ async fn test_simple() { ws::Frame::Binary(Bytes::from_static(&b"text"[..])) ); - io.send(&codec, ws::Message::Ping("text".into())) + io.send(ws::Message::Ping("text".into()), &codec) .await .unwrap(); let item = io.recv(&codec).await; @@ -123,8 +123,8 @@ async fn test_simple() { ); io.send( - &codec, ws::Message::Continuation(ws::Item::FirstText("text".into())), + &codec, ) .await .unwrap(); @@ -136,22 +136,22 @@ async fn test_simple() { assert!(io .send( - &codec, ws::Message::Continuation(ws::Item::FirstText("text".into())), + &codec, ) .await .is_err()); assert!(io .send( - &codec, ws::Message::Continuation(ws::Item::FirstBinary("text".into())), + &codec, ) .await .is_err()); io.send( - &codec, ws::Message::Continuation(ws::Item::Continue("text".into())), + &codec, ) .await .unwrap(); @@ -162,8 +162,8 @@ async fn test_simple() { ); io.send( - &codec, ws::Message::Continuation(ws::Item::Last("text".into())), + &codec, ) .await .unwrap(); @@ -175,23 +175,23 @@ async fn test_simple() { assert!(io .send( - &codec, ws::Message::Continuation(ws::Item::Continue("text".into())), + &codec, ) .await .is_err()); assert!(io .send( - &codec, ws::Message::Continuation(ws::Item::Last("text".into())), + &codec, ) .await .is_err()); io.send( - &codec, ws::Message::Continuation(ws::Item::FirstBinary(Bytes::from_static(b"bin"))), + &codec, ) .await .unwrap(); @@ -202,8 +202,8 @@ async fn test_simple() { ); io.send( - &codec, ws::Message::Continuation(ws::Item::Continue("text".into())), + &codec, ) .await .unwrap(); @@ -214,8 +214,8 @@ async fn test_simple() { ); io.send( - &codec, ws::Message::Continuation(ws::Item::Last("text".into())), + &codec, ) .await .unwrap(); @@ -226,8 +226,8 @@ async fn test_simple() { ); io.send( - &codec, ws::Message::Close(Some(ws::CloseCode::Normal.into())), + &codec, ) .await .unwrap(); @@ -265,7 +265,7 @@ async fn test_transport() { if let Some(item) = io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? { - io.send(&BytesCodec, item.freeze()).await.unwrap() + io.send(item.freeze(), &BytesCodec).await.unwrap() } else { break; } @@ -280,7 +280,7 @@ async fn test_transport() { let io = srv.ws().await.unwrap().into_inner().0; let codec = ws::Codec::default().client_mode(); - io.send(&codec, ws::Message::Binary(Bytes::from_static(b"text"))) + io.send(ws::Message::Binary(Bytes::from_static(b"text")), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); diff --git a/ntex/tests/http_ws_client.rs b/ntex/tests/http_ws_client.rs index 2ebb7d81..95c8558a 100644 --- a/ntex/tests/http_ws_client.rs +++ b/ntex/tests/http_ws_client.rs @@ -52,27 +52,27 @@ async fn test_simple() { // client service let (io, codec, _) = srv.ws().await.unwrap().into_inner(); - io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) + io.send(ws::Message::Text(ByteString::from_static("text")), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); - io.send(&codec, ws::Message::Binary("text".into())) + io.send(ws::Message::Binary("text".into()), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); - io.send(&codec, ws::Message::Ping("text".into())) + io.send(ws::Message::Ping("text".into()), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Pong("text".to_string().into())); io.send( - &codec, ws::Message::Close(Some(ws::CloseCode::Normal.into())), + &codec, ) .await .unwrap(); @@ -110,7 +110,7 @@ async fn test_transport() { // client service let io = srv.ws().await.unwrap().into_transport().await; - io.send(&BytesCodec, Bytes::from_static(b"text")) + io.send(Bytes::from_static(b"text"), &BytesCodec) .await .unwrap(); let item = io.recv(&BytesCodec).await.unwrap().unwrap(); diff --git a/ntex/tests/server.rs b/ntex/tests/server.rs index e57fe3b9..39be0358 100644 --- a/ntex/tests/server.rs +++ b/ntex/tests/server.rs @@ -77,7 +77,7 @@ fn test_run() { .disable_signals() .bind("test", addr, move |_| { fn_service(|io: Io| async move { - io.send(&BytesCodec, Bytes::from_static(b"test")) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); Ok::<_, ()>(()) diff --git a/ntex/tests/web_ws.rs b/ntex/tests/web_ws.rs index 2df011e2..aa83d98f 100644 --- a/ntex/tests/web_ws.rs +++ b/ntex/tests/web_ws.rs @@ -38,27 +38,27 @@ async fn web_ws() { // client service let (io, codec, _) = srv.ws().await.unwrap().into_inner(); - io.send(&codec, ws::Message::Text(ByteString::from_static("text"))) + io.send(ws::Message::Text(ByteString::from_static("text")), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); - io.send(&codec, ws::Message::Binary("text".into())) + io.send(ws::Message::Binary("text".into()), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); - io.send(&codec, ws::Message::Ping("text".into())) + io.send(ws::Message::Ping("text".into()), &codec) .await .unwrap(); let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Pong("text".to_string().into())); io.send( - &codec, ws::Message::Close(Some(ws::CloseCode::Normal.into())), + &codec, ) .await .unwrap();