From 3dd7dc68bd7a4029fd1165114d58181389583430 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 5 Oct 2020 12:08:16 +0600 Subject: [PATCH] prep cors release --- ntex-cors/CHANGES.md | 18 +- ntex-cors/Cargo.toml | 6 +- ntex-cors/src/lib.rs | 478 ++++++++++++++++++------------------------- 3 files changed, 206 insertions(+), 296 deletions(-) diff --git a/ntex-cors/CHANGES.md b/ntex-cors/CHANGES.md index 1686b0ca..28475dc6 100644 --- a/ntex-cors/CHANGES.md +++ b/ntex-cors/CHANGES.md @@ -1,19 +1,5 @@ # Changes -## [0.1.0] - 2020-04-xx +## [0.1.0] - 2020-10-05 -* Fork to ntex project - -## [0.2.0] - 2019-12-20 - -* Release - -## [0.2.0-alpha.3] - 2019-12-07 - -* Migrate to actix-web 2.0.0 - -* Bump `derive_more` crate version to 0.99.0 - -## [0.1.0] - 2019-06-15 - -* Move cors middleware to separate crate +* Initial release diff --git a/ntex-cors/Cargo.toml b/ntex-cors/Cargo.toml index 358d35ec..08f376a7 100644 --- a/ntex-cors/Cargo.toml +++ b/ntex-cors/Cargo.toml @@ -2,7 +2,7 @@ name = "ntex-cors" version = "0.1.0" authors = ["ntex contributors "] -description = "Cross-origin resource sharing (CORS) for Actix applications." +description = "Cross-origin resource sharing (CORS) for ntex applications." readme = "README.md" keywords = ["ntex", "web"] homepage = "https://ntex.rs" @@ -16,6 +16,6 @@ name = "ntex_cors" path = "src/lib.rs" [dependencies] -ntex = "0.1.21" +ntex = "0.1.24" derive_more = "0.99.5" -futures = "0.3.4" \ No newline at end of file +futures = "0.3.5" \ No newline at end of file diff --git a/ntex-cors/src/lib.rs b/ntex-cors/src/lib.rs index b8a2d960..cd1f12ae 100644 --- a/ntex-cors/src/lib.rs +++ b/ntex-cors/src/lib.rs @@ -54,19 +54,16 @@ use std::task::{Context, Poll}; use derive_more::Display; use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; use ntex::http::header::{self, HeaderName, HeaderValue}; -use ntex::http::{error::HttpError, Method, RequestHead, StatusCode, Uri}; +use ntex::http::{error::HttpError, HeaderMap, Method, RequestHead, StatusCode, Uri}; use ntex::service::{Service, Transform}; use ntex::web::dev::{WebRequest, WebResponse}; -use ntex::web::HttpResponse; -use ntex::web::{DefaultError, ErrorRenderer, WebResponseError}; +use ntex::web::{DefaultError, ErrorRenderer, HttpResponse, WebResponseError}; /// A set of errors that can occur during processing CORS #[derive(Debug, Display)] pub enum CorsError { /// The HTTP request header `Origin` is required but was not provided - #[display( - fmt = "The HTTP request header `Origin` is required but was not provided" - )] + #[display(fmt = "The HTTP request header `Origin` is required but was not provided")] MissingOrigin, /// The HTTP request header `Origin` could not be parsed correctly. #[display(fmt = "The HTTP request header `Origin` could not be parsed correctly.")] @@ -78,9 +75,7 @@ pub enum CorsError { )] MissingRequestMethod, /// The request header `Access-Control-Request-Method` has an invalid value - #[display( - fmt = "The request header `Access-Control-Request-Method` has an invalid value" - )] + #[display(fmt = "The request header `Access-Control-Request-Method` has an invalid value")] BadRequestMethod, /// The request header `Access-Control-Request-Headers` has an invalid /// value @@ -225,9 +220,7 @@ impl Cors { supports_credentials: false, vary_header: true, }; - CorsFactory { - inner: Rc::new(inner), - } + CorsFactory { inner: Rc::new(inner) } } /// Add an origin that are allowed to make requests. @@ -495,9 +488,7 @@ impl Cors { } if let AllOrSome::Some(ref origins) = cors.origins { - let s = origins - .iter() - .fold(String::new(), |s, v| format!("{}, {}", s, v)); + let s = origins.iter().fold(String::new(), |s, v| format!("{}, {}", s, v)); cors.origins_str = Some(HeaderValue::try_from(&s[2..]).unwrap()); } @@ -510,63 +501,17 @@ impl Cors { ); } - CorsFactory { - inner: Rc::new(cors), - } + CorsFactory { inner: Rc::new(cors) } } } -fn cors<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut Inner> { +fn cors<'a>(parts: &'a mut Option, err: &Option) -> Option<&'a mut Inner> { if err.is_some() { return None; } parts.as_mut() } -/// `Middleware` for Cross-origin resource sharing support -/// -/// The Cors struct contains the settings for CORS requests to be validated and -/// for responses to be generated. -pub struct CorsFactory { - inner: Rc, -} - -impl Transform for CorsFactory -where - S: Service, Response = WebResponse>, - S::Future: 'static, - Err: ErrorRenderer, - Err::Container: From, - CorsError: WebResponseError, -{ - type Request = WebRequest; - type Response = WebResponse; - type Error = S::Error; - type InitError = (); - type Transform = CorsMiddleware; - type Future = Ready>; - - fn new_transform(&self, service: S) -> Self::Future { - ok(CorsMiddleware { - service, - inner: self.inner.clone(), - }) - } -} - -/// `Middleware` for Cross-origin resource sharing support -/// -/// The Cors struct contains the settings for CORS requests to be validated and -/// for responses to be generated. -#[derive(Clone)] -pub struct CorsMiddleware { - service: S, - inner: Rc, -} - struct Inner { methods: HashSet, origins: AllOrSome>, @@ -601,12 +546,12 @@ impl Inner { } } - fn access_control_allow_origin(&self, req: &RequestHead) -> Option { + fn access_control_allow_origin(&self, headers: &HeaderMap) -> Option { match self.origins { AllOrSome::All => { if self.send_wildcard { Some(HeaderValue::from_static("*")) - } else if let Some(origin) = req.headers().get(&header::ORIGIN) { + } else if let Some(origin) = headers.get(&header::ORIGIN) { Some(origin.clone()) } else { None @@ -614,12 +559,10 @@ impl Inner { } AllOrSome::Some(ref origins) => { if let Some(origin) = - req.headers() - .get(&header::ORIGIN) - .filter(|o| match o.to_str() { - Ok(os) => origins.contains(os), - _ => false, - }) + headers.get(&header::ORIGIN).filter(|o| match o.to_str() { + Ok(os) => origins.contains(os), + _ => false, + }) { Some(origin.clone()) } else { @@ -650,9 +593,7 @@ impl Inner { match self.headers { AllOrSome::All => Ok(()), AllOrSome::Some(ref allowed_headers) => { - if let Some(hdr) = - req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) - { + if let Some(hdr) = req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) { if let Ok(headers) = hdr.to_str() { let mut hdrs = HashSet::new(); for hdr in headers.split(',') { @@ -677,6 +618,137 @@ impl Inner { } } } + + fn preflight_check( + &self, + req: &RequestHead, + ) -> Result, CorsError> { + if self.preflight && Method::OPTIONS == req.method { + self.validate_origin(req) + .and_then(|_| self.validate_allowed_method(req)) + .and_then(|_| self.validate_allowed_headers(req))?; + + // allowed headers + let headers = if let Some(headers) = self.headers.as_ref() { + Some( + HeaderValue::try_from( + &headers + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .unwrap(), + ) + } else if let Some(hdr) = req.headers.get(&header::ACCESS_CONTROL_REQUEST_HEADERS) { + Some(hdr.clone()) + } else { + None + }; + + let res = HttpResponse::Ok() + .if_some(self.max_age.as_ref(), |max_age, resp| { + let _ = resp.header( + header::ACCESS_CONTROL_MAX_AGE, + format!("{}", max_age).as_str(), + ); + }) + .if_some(headers, |headers, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); + }) + .if_some(self.access_control_allow_origin(req.headers()), |origin, resp| { + let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }) + .if_true(self.supports_credentials, |resp| { + resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + }) + .header( + header::ACCESS_CONTROL_ALLOW_METHODS, + &self + .methods + .iter() + .fold(String::new(), |s, v| s + "," + v.as_str()) + .as_str()[1..], + ) + .finish() + .into_body(); + + Ok(Either::Left(res)) + } else { + if req.headers.contains_key(&header::ORIGIN) { + // Only check requests with a origin header. + self.validate_origin(req)?; + } + Ok(Either::Right(())) + } + } + + fn handle_response(&self, headers: &mut HeaderMap, allowed_origin: Option) { + if let Some(origin) = allowed_origin { + headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + }; + + if let Some(ref expose) = self.expose_hdrs { + headers.insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap(), + ); + } + if self.supports_credentials { + headers.insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if self.vary_header { + let value = if let Some(hdr) = headers.get(&header::VARY) { + let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); + val.extend(hdr.as_bytes()); + val.extend(b", Origin"); + HeaderValue::try_from(&val[..]).unwrap() + } else { + HeaderValue::from_static("Origin") + }; + headers.insert(header::VARY, value); + } + } +} + +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +pub struct CorsFactory { + inner: Rc, +} + +impl Transform for CorsFactory +where + S: Service, Response = WebResponse>, + S::Future: 'static, + Err: ErrorRenderer, + Err::Container: From, + CorsError: WebResponseError, +{ + type Request = WebRequest; + type Response = WebResponse; + type Error = S::Error; + type InitError = (); + type Transform = CorsMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(CorsMiddleware { service, inner: self.inner.clone() }) + } +} + +/// `Middleware` for Cross-origin resource sharing support +/// +/// The Cors struct contains the settings for CORS requests to be validated and +/// for responses to be generated. +#[derive(Clone)] +pub struct CorsMiddleware { + service: S, + inner: Rc, } impl Service for CorsMiddleware @@ -704,125 +776,27 @@ where } fn call(&self, req: WebRequest) -> Self::Future { - if self.inner.preflight && Method::OPTIONS == *req.method() { - if let Err(e) = self - .inner - .validate_origin(req.head()) - .and_then(|_| self.inner.validate_allowed_method(req.head())) - .and_then(|_| self.inner.validate_allowed_headers(req.head())) - { - return Either::Left(ok(req.render_error(e))); - } + match self.inner.preflight_check(req.head()) { + Ok(Either::Left(res)) => Either::Left(ok(req.into_response(res))), + Ok(Either::Right(_)) => { + let inner = self.inner.clone(); + let has_origin = req.headers().contains_key(&header::ORIGIN); + let allowed_origin = inner.access_control_allow_origin(req.headers()); + let fut = self.service.call(req); - // allowed headers - let headers = if let Some(headers) = self.inner.headers.as_ref() { - Some( - HeaderValue::try_from( - &headers - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..], - ) - .unwrap(), - ) - } else if let Some(hdr) = - req.headers().get(&header::ACCESS_CONTROL_REQUEST_HEADERS) - { - Some(hdr.clone()) - } else { - None - }; + Either::Right( + async move { + let mut res = fut.await?; - let res = HttpResponse::Ok() - .if_some(self.inner.max_age.as_ref(), |max_age, resp| { - let _ = resp.header( - header::ACCESS_CONTROL_MAX_AGE, - format!("{}", max_age).as_str(), - ); - }) - .if_some(headers, |headers, resp| { - let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); - }) - .if_some( - self.inner.access_control_allow_origin(req.head()), - |origin, resp| { - let _ = resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }, - ) - .if_true(self.inner.supports_credentials, |resp| { - resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - }) - .header( - header::ACCESS_CONTROL_ALLOW_METHODS, - &self - .inner - .methods - .iter() - .fold(String::new(), |s, v| s + "," + v.as_str()) - .as_str()[1..], - ) - .finish() - .into_body(); - - Either::Left(ok(req.into_response(res))) - } else { - if req.headers().contains_key(&header::ORIGIN) { - // Only check requests with a origin header. - if let Err(e) = self.inner.validate_origin(req.head()) { - return Either::Left(ok(req.render_error(e))); - } - } - - let inner = self.inner.clone(); - let has_origin = req.headers().contains_key(&header::ORIGIN); - let fut = self.service.call(req); - - Either::Right( - async move { - let res = fut.await; - - if has_origin { - let mut res = res?; - if let Some(origin) = - inner.access_control_allow_origin(res.request().head()) - { - res.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); - }; - - if let Some(ref expose) = inner.expose_hdrs { - res.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - HeaderValue::try_from(expose.as_str()).unwrap(), - ); - } - if inner.supports_credentials { - res.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } - if inner.vary_header { - let value = if let Some(hdr) = - res.headers_mut().get(&header::VARY) - { - let mut val: Vec = - Vec::with_capacity(hdr.as_bytes().len() + 8); - val.extend(hdr.as_bytes()); - val.extend(b", Origin"); - HeaderValue::try_from(&val[..]).unwrap() - } else { - HeaderValue::from_static("Origin") - }; - res.headers_mut().insert(header::VARY, value); + if has_origin { + inner.handle_response(res.headers_mut(), allowed_origin); } Ok(res) - } else { - res } - } - .boxed_local(), - ) + .boxed_local(), + ) + } + Err(e) => Either::Left(ok(req.render_error(e))), } } } @@ -842,13 +816,9 @@ mod tests { #[ntex::test] async fn validate_origin_allows_all_origins() { - let cors = Cors::new() - .finish() - .new_transform(test::ok_service()) - .await - .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); + let cors = Cors::new().finish().new_transform(test::ok_service()).await.unwrap(); + let req = + TestRequest::with_header("Origin", "https://www.example.com").to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); @@ -856,12 +826,9 @@ mod tests { #[ntex::test] async fn default() { - let cors = Cors::default() - .new_transform(test::ok_service()) - .await - .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); + let cors = Cors::default().new_transform(test::ok_service()).await.unwrap(); + let req = + TestRequest::with_header("Origin", "https://www.example.com").to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); @@ -900,27 +867,18 @@ mod tests { let req = TestRequest::with_header("Origin", "https://www.example.com") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT") .method(Method::OPTIONS) .to_srv_request(); let resp = test::call_service(&cors, req).await; assert_eq!( &b"*"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + resp.headers().get(&header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes() ); assert_eq!( &b"3600"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_MAX_AGE) - .unwrap() - .as_bytes() + resp.headers().get(&header::ACCESS_CONTROL_MAX_AGE).unwrap().as_bytes() ); let hdr = resp .headers() @@ -932,12 +890,8 @@ mod tests { assert!(hdr.contains("accept")); assert!(hdr.contains("content-type")); - let methods = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_METHODS) - .unwrap() - .to_str() - .unwrap(); + let methods = + resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().to_str().unwrap(); assert!(methods.contains("POST")); assert!(methods.contains("GET")); assert!(methods.contains("OPTIONS")); @@ -946,10 +900,7 @@ mod tests { let req = TestRequest::with_header("Origin", "https://www.example.com") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT") .method(Method::OPTIONS) .to_srv_request(); @@ -1013,10 +964,7 @@ mod tests { let req = TestRequest::default().method(Method::GET).to_srv_request(); let resp = test::call_service(&cors, req).await; - assert!(resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .is_none()); + assert!(resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none()); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) @@ -1024,10 +972,7 @@ mod tests { let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://www.example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes() ); } @@ -1054,15 +999,9 @@ mod tests { let resp = test::call_service(&cors, req).await; assert_eq!( &b"*"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); - assert_eq!( - &b"Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes() ); + assert_eq!(&b"Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes()); { let headers = resp @@ -1081,22 +1020,23 @@ mod tests { } let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let cors = Cors::new() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(exposed_headers.clone()) - .expose_headers(exposed_headers.clone()) - .allowed_header(header::CONTENT_TYPE) - .finish() - .new_transform(fn_service(|req: WebRequest| { - ok::<_, std::convert::Infallible>(req.into_response( - HttpResponse::Ok().header(header::VARY, "Accept").finish(), - )) - })) - .await - .unwrap(); + let cors = + Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(fn_service(|req: WebRequest| { + ok::<_, std::convert::Infallible>(req.into_response( + HttpResponse::Ok().header(header::VARY, "Accept").finish(), + )) + })) + .await + .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); @@ -1121,12 +1061,8 @@ mod tests { .to_srv_request(); let resp = test::call_service(&cors, req).await; - let origins_str = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() - .unwrap(); + let origins_str = + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().to_str().unwrap(); assert_eq!("https://www.example.com", origins_str); } @@ -1149,10 +1085,7 @@ mod tests { let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes() ); let req = TestRequest::with_header("Origin", "https://example.org") @@ -1162,10 +1095,7 @@ mod tests { let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes() ); } @@ -1188,10 +1118,7 @@ mod tests { let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes() ); let req = TestRequest::with_header("Origin", "https://example.org") @@ -1202,10 +1129,7 @@ mod tests { let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() + resp.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap().as_bytes() ); } }