Fix handling of connection header #370 (#373)

This commit is contained in:
Nikolay Kim 2024-06-22 13:50:17 +02:00 committed by GitHub
parent e0b5284fdd
commit 0255df9f16
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 179 additions and 49 deletions

View file

@ -1,5 +1,11 @@
# Changes
## [2.0.2] - 2024-06-22
* web: Cleanup http request in cache
* http: Fix handling of connection header
## [2.0.1] - 2024-05-29
* http: Fix handling payload timer after payload got consumed

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "2.0.1"
version = "2.0.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"

View file

@ -1,10 +1,9 @@
use std::{future::Future, io};
use std::{fmt, future::Future, io};
use crate::http::message::CurrentIo;
use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError};
use crate::io::{Filter, Io, IoBoxed};
#[derive(Debug)]
pub enum Control<F, Err> {
/// New request is loaded
NewRequest(NewRequest),
@ -108,6 +107,29 @@ impl<F, Err> Control<F, Err> {
}
}
impl<F, Err> fmt::Debug for Control<F, Err>
where
Err: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Control::NewRequest(msg) => {
f.debug_tuple("Control::NewRequest").field(msg).finish()
}
Control::Upgrade(msg) => f.debug_tuple("Control::Upgrade").field(msg).finish(),
Control::Expect(msg) => f.debug_tuple("Control::Expect").field(msg).finish(),
Control::Closed(msg) => f.debug_tuple("Control::Closed").field(msg).finish(),
Control::Error(msg) => f.debug_tuple("Control::Error").field(msg).finish(),
Control::ProtocolError(msg) => {
f.debug_tuple("Control::ProtocolError").field(msg).finish()
}
Control::PeerGone(msg) => {
f.debug_tuple("Control::PeerGone").field(msg).finish()
}
}
}
}
#[derive(Debug)]
pub struct NewRequest(Request);
@ -164,7 +186,6 @@ impl NewRequest {
}
}
#[derive(Debug)]
pub struct Upgrade<F> {
req: Request,
io: Io<F>,
@ -244,6 +265,16 @@ impl<F: Filter> Upgrade<F> {
}
}
impl<F> fmt::Debug for Upgrade<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Upgrade")
.field("req", &self.req)
.field("io", &self.io)
.field("codec", &self.codec)
.finish()
}
}
/// Connection closed message
#[derive(Debug)]
pub struct Closed;

View file

@ -153,16 +153,8 @@ pub(super) trait MessageType: Sized {
}
// connection keep-alive state
header::CONNECTION => {
ka = if let Ok(conn) = value.to_str().map(|conn| conn.trim()) {
if conn.eq_ignore_ascii_case("keep-alive") {
Some(ConnectionType::KeepAlive)
} else if conn.eq_ignore_ascii_case("close") {
Some(ConnectionType::Close)
} else if conn.eq_ignore_ascii_case("upgrade") {
Some(ConnectionType::Upgrade)
} else {
None
}
ka = if let Ok(val) = value.to_str() {
connection_type(val)
} else {
None
};
@ -398,6 +390,59 @@ impl MessageType for ResponseHead {
}
}
const S_KEEP_ALIVE: &str = "keep-alive";
const S_CLOSE: &str = "close";
const S_UPGRADE: &str = "upgrade";
fn connection_type(val: &str) -> Option<ConnectionType> {
let l = val.len();
let bytes = val.as_bytes();
for i in 0..bytes.len() {
if i >= S_CLOSE.len() {
return None;
}
let result = match bytes[i] {
b'k' | b'K' => {
let pos = i + S_KEEP_ALIVE.len();
if l >= pos && val[i..pos].eq_ignore_ascii_case(S_KEEP_ALIVE) {
Some((ConnectionType::KeepAlive, pos))
} else {
None
}
}
b'c' | b'C' => {
let pos = i + S_CLOSE.len();
if l >= pos && val[i..pos].eq_ignore_ascii_case(S_CLOSE) {
Some((ConnectionType::Close, pos))
} else {
None
}
}
b'u' | b'U' => {
let pos = i + S_UPGRADE.len();
if l >= pos && val[i..pos].eq_ignore_ascii_case(S_UPGRADE) {
Some((ConnectionType::Upgrade, pos))
} else {
None
}
}
_ => continue,
};
if let Some((t, pos)) = result {
let next = pos + 1;
if val.len() > next {
if matches!(bytes[next], b' ' | b',' | b'\r' | b'\n') {
return Some(t);
}
} else {
return Some(t);
}
}
}
None
}
#[derive(Clone, Copy)]
pub(super) struct HeaderIndex {
pub(super) name: (usize, usize),
@ -802,6 +847,22 @@ mod tests {
}
}
#[test]
fn test_connection_type() {
for s in &["Close", "Close\r\n", "close,", "close "] {
assert_eq!(connection_type(s), Some(ConnectionType::Close));
}
for s in &["upgrade", "upGrade\r\n", "upgrade,", "upgrade "] {
assert_eq!(connection_type(s), Some(ConnectionType::Upgrade));
}
for s in &["keep-alive", "keep-Alive\r\n", "keep-alive,", "Keep-alive "] {
assert_eq!(connection_type(s), Some(ConnectionType::KeepAlive));
}
for s in &["keep-aliv", "clos\r\n", "clos", "upgrad"] {
assert_eq!(connection_type(s), None);
}
}
#[test]
fn test_parse_partial() {
let mut buf = BytesMut::from("PUT /test HTTP/1");

View file

@ -1,9 +1,8 @@
use std::{cell::Ref, cell::RefCell, cell::RefMut, net, rc::Rc};
use std::{cell::Ref, cell::RefCell, cell::RefMut, fmt, net, rc::Rc};
use bitflags::bitflags;
use crate::http::header::HeaderMap;
use crate::http::{h1::Codec, Method, StatusCode, Uri, Version};
use crate::http::{h1::Codec, header::HeaderMap, Method, StatusCode, Uri, Version};
use crate::io::{types, IoBoxed, IoRef};
use crate::util::Extensions;
@ -29,7 +28,7 @@ bitflags! {
}
}
pub(crate) trait Head: Default + 'static {
pub(crate) trait Head: Default + 'static + fmt::Debug {
fn clear(&mut self);
fn with_pool<F, R>(f: F) -> R
@ -211,6 +210,21 @@ impl RequestHead {
pub fn take_io(&self) -> Option<(IoBoxed, Codec)> {
self.io.take()
}
#[doc(hidden)]
pub fn remove_io(&mut self) {
self.io = CurrentIo::None;
}
pub(crate) fn take_io_rc(
&self,
) -> Option<Rc<(IoRef, RefCell<Option<(IoBoxed, Codec)>>)>> {
if let CurrentIo::Io(ref r) = self.io {
Some(r.clone())
} else {
None
}
}
}
#[derive(Debug)]
@ -371,8 +385,8 @@ impl ResponseHead {
}
}
pub(crate) fn set_io(&mut self, head: &RequestHead) {
self.io = head.io.clone();
pub(crate) fn set_io(&mut self, io: Rc<(IoRef, RefCell<Option<(IoBoxed, Codec)>>)>) {
self.io = CurrentIo::Io(io)
}
}
@ -398,6 +412,7 @@ impl Head for ResponseHead {
}
}
#[derive(Debug)]
pub(crate) struct Message<T: Head> {
head: Rc<T>,
}
@ -434,7 +449,15 @@ impl<T: Head> std::ops::DerefMut for Message<T> {
impl<T: Head> Drop for Message<T> {
fn drop(&mut self) {
T::with_pool(|p| p.release(self.head.clone()));
T::with_pool(|pool| {
let v = &mut pool.0.borrow_mut();
if v.len() < 128 {
Rc::get_mut(&mut self.head)
.expect("Multiple copies exist")
.clear();
v.push(self.head.clone());
}
});
}
}
@ -452,24 +475,14 @@ impl<T: Head> MessagePool<T> {
/// Get message from the pool
#[inline]
fn get_message(&self) -> Message<T> {
if let Some(mut msg) = self.0.borrow_mut().pop() {
if let Some(r) = Rc::get_mut(&mut msg) {
r.clear();
let head = if let Some(mut msg) = self.0.borrow_mut().pop() {
if let Some(msg) = Rc::get_mut(&mut msg) {
msg.clear();
}
Message { head: msg }
msg
} else {
Message {
head: Rc::new(T::default()),
}
}
}
#[inline]
/// Release request instance
fn release(&self, msg: Rc<T>) {
let v = &mut self.0.borrow_mut();
if v.len() < 128 {
v.push(msg);
}
Rc::new(T::default())
};
Message { head }
}
}

View file

@ -117,12 +117,7 @@ impl Request {
/// Check if request requires connection upgrade
#[inline]
pub fn upgrade(&self) -> bool {
if let Some(conn) = self.head().headers.get(header::CONNECTION) {
if let Ok(s) = conn.to_str() {
return s.to_lowercase().contains("upgrade");
}
}
self.head().method == Method::CONNECT
self.head().upgrade() || self.head().method == Method::CONNECT
}
/// Io reference for current connection

View file

@ -245,9 +245,10 @@ impl HttpMessage for HttpRequest {
impl Drop for HttpRequest {
fn drop(&mut self) {
if Rc::strong_count(&self.0) == 1 {
let v = &mut self.0.pool.0.borrow_mut();
if let Some(inner) = Rc::get_mut(&mut self.0) {
let v = &mut inner.pool.0.borrow_mut();
if v.len() < 128 {
inner.head.remove_io();
self.extensions_mut().clear();
v.push(self.0.clone());
}

View file

@ -125,9 +125,8 @@ impl WebResponse {
impl From<WebResponse> for Response<Body> {
fn from(mut res: WebResponse) -> Response<Body> {
let head = res.response.head_mut();
if res.request.head().upgrade() {
head.set_io(res.request.head());
if let Some(io) = res.request.head().take_io_rc() {
res.response.head_mut().set_io(io);
}
res.response
}

View file

@ -761,6 +761,7 @@ enum HttpVer {
}
#[derive(Clone)]
#[allow(clippy::large_enum_variant)]
enum StreamType {
Tcp,
#[cfg(feature = "openssl")]

View file

@ -86,6 +86,29 @@ async fn web_no_ws() {
));
}
#[ntex::test]
async fn web_no_ws_2() {
let srv = test::server(|| {
App::new().service(
web::resource("/")
.route(web::to(|| async { HttpResponse::Ok().body("Hello world") })),
)
});
let mut response = srv
.get("/")
.no_decompress()
.header("test", "h2c")
.header("connection", "upgrade, test")
.set_connection_type(ntex::http::ConnectionType::Upgrade)
.send()
.await
.unwrap();
assert!(response.status().is_success());
let body = response.body().await.unwrap();
assert_eq!(body, b"Hello world");
}
#[ntex::test]
async fn web_ws_client() {
let srv = test::server(|| {