simplify framed dispatcher states

This commit is contained in:
Nikolay Kim 2020-04-13 09:48:46 +06:00
parent f52f7e616c
commit d7699b74d7
3 changed files with 32 additions and 28 deletions

View file

@ -126,13 +126,13 @@ impl<T, U> Framed<T, U> {
#[inline]
/// Get read buffer.
pub fn read_buf_mut(&mut self) -> &mut BytesMut {
pub fn read_buf(&mut self) -> &mut BytesMut {
&mut self.read_buf
}
#[inline]
/// Get write buffer.
pub fn write_buf_mut(&mut self) -> &mut BytesMut {
pub fn write_buf(&mut self) -> &mut BytesMut {
&mut self.write_buf
}
@ -556,6 +556,8 @@ mod tests {
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
Pin::new(&mut server).start_send(data).unwrap();
assert_eq!(client.read_any(), b"".as_ref());
assert_eq!(server.read_buf(), b"".as_ref());
assert_eq!(server.write_buf(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await

View file

@ -1,6 +1,6 @@
# Changes
## [0.1.9] - 2020-04-xx
## [0.1.9] - 2020-04-13
* ntex::util: Refcator framed dispatcher

View file

@ -180,8 +180,7 @@ where
enum FramedState<S: Service, U: Encoder + Decoder> {
Processing,
Error(DispatcherError<S::Error, U>),
FlushAndStop,
FlushAndStop(Option<DispatcherError<S::Error, U>>),
Shutdown(Option<DispatcherError<S::Error, U>>),
ShutdownIo(Delay, Option<Result<(), DispatcherError<S::Error, U>>>),
}
@ -192,15 +191,6 @@ enum PollResult {
Pending,
}
impl<S: Service, U: Encoder + Decoder> FramedState<S, U> {
fn take_error(&mut self) -> DispatcherError<S::Error, U> {
match std::mem::replace(self, FramedState::Processing) {
FramedState::Error(err) => err,
_ => panic!(),
}
}
}
struct InnerDispatcher<S, T, U, Out>
where
S: Service<Request = Request<U>, Response = Option<Response<U>>>,
@ -263,7 +253,8 @@ where
}
Poll::Pending => return PollResult::Pending,
Poll::Ready(Err(err)) => {
self.state = FramedState::Error(DispatcherError::Service(err));
self.state =
FramedState::FlushAndStop(Some(DispatcherError::Service(err)));
return PollResult::Continue;
}
}
@ -285,7 +276,9 @@ where
continue;
}
Poll::Ready(Some(Err(err))) => {
self.state = FramedState::Error(DispatcherError::Service(err));
self.state = FramedState::FlushAndStop(Some(
DispatcherError::Service(err),
));
return PollResult::Continue;
}
Poll::Ready(None) | Poll::Pending => {}
@ -304,7 +297,7 @@ where
}
Poll::Ready(None) => {
let _ = self.sink.take();
self.state = FramedState::FlushAndStop;
self.state = FramedState::FlushAndStop(None);
return PollResult::Continue;
}
Poll::Pending => (),
@ -346,16 +339,7 @@ where
return Poll::Pending;
}
}
FramedState::Error(_) => {
// flush write buffer
if !self.framed.is_write_buf_empty() {
if let Poll::Pending = self.framed.flush(cx) {
return Poll::Pending;
}
}
self.state = FramedState::Shutdown(Some(self.state.take_error()));
}
FramedState::FlushAndStop => {
FramedState::FlushAndStop(ref mut err) => {
// drain service responses
match Pin::new(&mut self.rx).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => {
@ -385,7 +369,7 @@ where
Poll::Ready(_) => (),
}
};
self.state = FramedState::Shutdown(None);
self.state = FramedState::Shutdown(err.take());
}
FramedState::Shutdown(ref mut err) => {
return if self.service.poll_shutdown(cx, err.is_some()).is_ready() {
@ -440,7 +424,9 @@ where
#[cfg(test)]
mod tests {
use bytes::{Bytes, BytesMut};
use derive_more::Display;
use futures::future::ok;
use std::io;
use super::*;
use crate::channel::mpsc;
@ -448,6 +434,22 @@ mod tests {
use crate::rt::time::delay_for;
use crate::testing::Io;
#[test]
fn test_err() {
#[derive(Debug, Display)]
struct TestError;
type T = DispatcherError<TestError, BytesCodec>;
let err = T::Encoder(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatcherError::Encoder"));
assert!(format!("{}", err).contains("Custom"));
let err = T::Decoder(io::Error::new(io::ErrorKind::Other, "err"));
assert!(format!("{:?}", err).contains("DispatcherError::Decoder"));
assert!(format!("{}", err).contains("Custom"));
let err = T::from(TestError);
assert!(format!("{:?}", err).contains("DispatcherError::Service"));
assert_eq!(format!("{}", err), "TestError");
}
#[ntex_rt::test]
async fn test_basic() {
let (client, server) = Io::create();