use MaybeUninit type for header parsing. (#107)

This commit is contained in:
fakeshadow 2022-04-26 09:00:21 +08:00 committed by GitHub
parent fb6c16fcb2
commit 114a7b6dba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -212,17 +212,15 @@ impl MessageType for Request {
#[allow(clippy::uninit_assumed_init)]
fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError> {
// Unsafe: we read this data only after httparse has parsed the headers into it.
// performance bump for pipeline benchmarks.
let mut headers: [HeaderIndex; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() };
let mut headers: [MaybeUninit<HeaderIndex>; MAX_HEADERS] = uninit_array();
let (len, method, uri, ver, h_len) = {
let mut parsed: [httparse::Header<'_>; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() };
let (len, method, uri, ver, headers) = {
let mut parsed: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] =
uninit_array();
let mut req = httparse::Request::new(&mut parsed);
match req.parse(src)? {
let mut req = httparse::Request::new(&mut []);
match req.parse_with_uninit_headers(src, &mut parsed)? {
httparse::Status::Complete(len) => {
let method = Method::from_bytes(req.method.unwrap().as_bytes())
.map_err(|_| ParseError::Method)?;
@ -232,9 +230,14 @@ impl MessageType for Request {
} else {
Version::HTTP_10
};
HeaderIndex::record(src, req.headers, &mut headers);
(len, method, uri, version, req.headers.len())
(
len,
method,
uri,
version,
HeaderIndex::record(src, req.headers, &mut headers),
)
}
httparse::Status::Partial => {
if src.len() >= MAX_BUFFER_SIZE {
@ -249,7 +252,7 @@ impl MessageType for Request {
let mut msg = Request::new();
// convert headers
let length = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?;
let length = msg.set_headers(&src.split_to(len).freeze(), headers)?;
// payload decoder
let decoder = match length {
@ -293,17 +296,18 @@ impl MessageType for ResponseHead {
#[allow(clippy::uninit_assumed_init)]
fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError> {
// Unsafe: we read this data only after httparse has parsed the headers into it.
// performance bump for pipeline benchmarks.
let mut headers: [HeaderIndex; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() };
let mut headers: [MaybeUninit<HeaderIndex>; MAX_HEADERS] = uninit_array();
let (len, ver, status, h_len) = {
let mut parsed: [httparse::Header<'_>; MAX_HEADERS] =
unsafe { MaybeUninit::uninit().assume_init() };
let (len, ver, status, headers) = {
let mut parsed: [MaybeUninit<httparse::Header<'_>>; MAX_HEADERS] =
uninit_array();
let mut res = httparse::Response::new(&mut parsed);
match res.parse(src)? {
let mut res = httparse::Response::new(&mut []);
match httparse::ParserConfig::default().parse_response_with_uninit_headers(
&mut res,
src,
&mut parsed,
)? {
httparse::Status::Complete(len) => {
let version = if res.version.unwrap() == 1 {
Version::HTTP_11
@ -312,9 +316,13 @@ impl MessageType for ResponseHead {
};
let status = StatusCode::from_u16(res.code.unwrap())
.map_err(|_| ParseError::Status)?;
HeaderIndex::record(src, res.headers, &mut headers);
(len, version, status, res.headers.len())
(
len,
version,
status,
HeaderIndex::record(src, res.headers, &mut headers),
)
}
httparse::Status::Partial => {
return if src.len() >= MAX_BUFFER_SIZE {
@ -331,7 +339,7 @@ impl MessageType for ResponseHead {
msg.version = ver;
// convert headers
let length = msg.set_headers(&src.split_to(len).freeze(), &headers[..h_len])?;
let length = msg.set_headers(&src.split_to(len).freeze(), headers)?;
// message payload
let decoder = if let PayloadLength::Payload(pl) = length {
@ -360,19 +368,35 @@ pub(super) struct HeaderIndex {
}
impl HeaderIndex {
pub(super) fn record(
pub(super) fn record<'a>(
bytes: &[u8],
headers: &[httparse::Header<'_>],
indices: &mut [HeaderIndex],
) {
indices: &'a mut [MaybeUninit<HeaderIndex>],
) -> &'a [HeaderIndex] {
let bytes_ptr = bytes.as_ptr() as usize;
for (header, indices) in headers.iter().zip(indices.iter_mut()) {
let name_start = header.name.as_ptr() as usize - bytes_ptr;
let name_end = name_start + header.name.len();
indices.name = (name_start, name_end);
let value_start = header.value.as_ptr() as usize - bytes_ptr;
let value_end = value_start + header.value.len();
indices.value = (value_start, value_end);
let init_len = headers
.iter()
.zip(indices.iter_mut())
.map(|(header, indices)| {
let name_start = header.name.as_ptr() as usize - bytes_ptr;
let name_end = name_start + header.name.len();
let value_start = header.value.as_ptr() as usize - bytes_ptr;
let value_end = value_start + header.value.len();
indices.write(HeaderIndex {
name: (name_start, name_end),
value: (value_start, value_end),
})
})
.count();
// SAFETY:
//
// `init_len` is the number of leading slots the iterator above wrote to,
// so every element of `indices[..init_len]` has been initialized.
unsafe {
&*(&indices[..init_len] as *const [MaybeUninit<HeaderIndex>]
as *const [HeaderIndex])
}
}
}
@ -671,6 +695,11 @@ impl ChunkedState {
}
}
/// Returns an array of `LEN` slots of `MaybeUninit<T>`, none of which has been written to.
///
/// Callers must `write` a slot before reading a `T` out of it.
fn uninit_array<T, const LEN: usize>() -> [MaybeUninit<T>; LEN] {
    let slots = MaybeUninit::<[MaybeUninit<T>; LEN]>::uninit();
    // SAFETY: each element is itself a `MaybeUninit<T>`, which carries no
    // initialization requirement, so the outer array needs no initialization either.
    unsafe { slots.assume_init() }
}
#[cfg(test)]
mod tests {
use super::*;