From f062c249bf6da6af0190ac61ecf7cffeaa311b45 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 10 Apr 2020 22:43:32 +0600 Subject: [PATCH] convert cors --- Cargo.toml | 5 +- ntex-cors/CHANGES.md | 4 + ntex-cors/Cargo.toml | 2 +- ntex-cors/src/lib.rs | 225 ++++++++++++++++++++++++------------------- 4 files changed, 135 insertions(+), 101 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dc7c3a6c..c510e418 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,4 +5,7 @@ members = [ "ntex-identity", "ntex-multipart", "ntex-session", -] \ No newline at end of file +] + +[patch.crates-io] +ntex = { path = "../ntex/ntex/" } diff --git a/ntex-cors/CHANGES.md b/ntex-cors/CHANGES.md index 8022ea4e..1686b0ca 100644 --- a/ntex-cors/CHANGES.md +++ b/ntex-cors/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.1.0] - 2020-04-xx + +* Fork to ntex project + ## [0.2.0] - 2019-12-20 * Release diff --git a/ntex-cors/Cargo.toml b/ntex-cors/Cargo.toml index fa094b32..c738009c 100644 --- a/ntex-cors/Cargo.toml +++ b/ntex-cors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-cors" -version = "0.2.0" +version = "0.1.0" authors = ["Nikolay Kim "] description = "Cross-origin resource sharing (CORS) for Actix applications." readme = "README.md" diff --git a/ntex-cors/src/lib.rs b/ntex-cors/src/lib.rs index 429fe9ea..d431e2bf 100644 --- a/ntex-cors/src/lib.rs +++ b/ntex-cors/src/lib.rs @@ -1,5 +1,5 @@ #![allow(clippy::borrow_interior_mutable_const, clippy::type_complexity)] -//! Cross-origin resource sharing (CORS) for Actix applications +//! Cross-origin resource sharing (CORS) for ntex applications //! //! CORS middleware could be used with application and with resource. //! Cors middleware could be used as parameter for `App::wrap()`, @@ -7,16 +7,18 @@ //! //! # Example //! -//! ```rust -//! use actix_cors::Cors; -//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; +//! ```rust,no_run +//! use ntex_cors::Cors; +//! use ntex::{http, web}; +//! use ntex::web::{App, HttpRequest, HttpResponse}; //! //! async fn index(req: HttpRequest) -> &'static str { //! "Hello world" //! } //! -//! fn main() -> std::io::Result<()> { -//! HttpServer::new(|| App::new() +//! #[ntex::main] +//! async fn main() -> std::io::Result<()> { +//! web::server(|| App::new() //! .wrap( //! Cors::new() // <- Construct CORS middleware builder //! .allowed_origin("https://www.rust-lang.org/") @@ -28,11 +30,11 @@ //! .service( //! web::resource("/index.html") //! .route(web::get().to(index)) -//! .route(web::head().to(|| HttpResponse::MethodNotAllowed())) +//! .route(web::head().to(|| async { HttpResponse::MethodNotAllowed() })) //! )) -//! .bind("127.0.0.1:8080")?; -//! -//! Ok(()) +//! .bind("127.0.0.1:8080")? +//! .run() +//! .await //! } //! ``` //! In this example custom *CORS* middleware get registered for "/index.html" @@ -42,17 +44,18 @@ use std::collections::HashSet; use std::convert::TryFrom; use std::iter::FromIterator; +use std::marker::PhantomData; use std::rc::Rc; use std::task::{Context, Poll}; -use actix_service::{Service, Transform}; -use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; -use actix_web::error::{Error, ResponseError, Result}; -use actix_web::http::header::{self, HeaderName, HeaderValue}; -use actix_web::http::{self, Error as HttpError, Method, StatusCode, Uri}; -use actix_web::HttpResponse; use derive_more::Display; use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; +use ntex::http::header::{self, HeaderName, HeaderValue}; +use ntex::http::{error::HttpError, Method, RequestHead, StatusCode, Uri}; +use ntex::service::{Service, Transform}; +use ntex::web::dev::{WebRequest, WebResponse}; +use ntex::web::HttpResponse; +use ntex::web::{DefaultError, ErrorRenderer, WebResponseError}; /// A set of errors that can occur during processing CORS #[derive(Debug, Display)] @@ -93,14 +96,11 @@ pub enum CorsError { HeadersNotAllowed, } -impl ResponseError for CorsError { +/// DefaultError renderer support +impl WebResponseError for CorsError { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST } - - fn error_response(&self) -> HttpResponse { - HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into()) - } } /// An enum signifying that some of type T is allowed, or `All` (everything is @@ -157,8 +157,8 @@ impl AllOrSome { /// # Example /// /// ```rust -/// use actix_cors::Cors; -/// use actix_web::http::header; +/// use ntex_cors::Cors; +/// use ntex::http::header; /// /// # fn main() { /// let cors = Cors::new() @@ -170,16 +170,21 @@ impl AllOrSome { /// # } /// ``` #[derive(Default)] -pub struct Cors { +pub struct Cors { cors: Option, methods: bool, - error: Option, expose_hdrs: HashSet, + error: Option, + _t: PhantomData, } -impl Cors { +impl Cors +where + Err: ErrorRenderer, + CorsError: WebResponseError, +{ /// Build a new CORS middleware instance - pub fn new() -> Cors { + pub fn new() -> Self { Cors { cors: Some(Inner { origins: AllOrSome::All, @@ -196,11 +201,12 @@ impl Cors { methods: false, error: None, expose_hdrs: HashSet::new(), + _t: PhantomData, } } /// Build a new CORS default middleware - pub fn default() -> CorsFactory { + pub fn default() -> CorsFactory { let inner = Inner { origins: AllOrSome::default(), origins_str: None, @@ -226,6 +232,7 @@ impl Cors { }; CorsFactory { inner: Rc::new(inner), + _t: PhantomData, } } @@ -246,7 +253,7 @@ impl Cors { /// Defaults to `All`. /// /// Builder panics if supplied origin is not valid uri. - pub fn allowed_origin(mut self, origin: &str) -> Cors { + pub fn allowed_origin(mut self, origin: &str) -> Self { if let Some(cors) = cors(&mut self.cors, &self.error) { match Uri::try_from(origin) { Ok(_) => { @@ -272,7 +279,7 @@ impl Cors { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` - pub fn allowed_methods(mut self, methods: U) -> Cors + pub fn allowed_methods(mut self, methods: U) -> Self where U: IntoIterator, Method: TryFrom, @@ -296,7 +303,7 @@ impl Cors { } /// Set an allowed header - pub fn allowed_header(mut self, header: H) -> Cors + pub fn allowed_header(mut self, header: H) -> Self where HeaderName: TryFrom, >::Error: Into, @@ -328,7 +335,7 @@ impl Cors { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// Defaults to `All`. - pub fn allowed_headers(mut self, headers: U) -> Cors + pub fn allowed_headers(mut self, headers: U) -> Self where U: IntoIterator, HeaderName: TryFrom, @@ -363,7 +370,7 @@ impl Cors { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// This defaults to an empty set. - pub fn expose_headers(mut self, headers: U) -> Cors + pub fn expose_headers(mut self, headers: U) -> Self where U: IntoIterator, HeaderName: TryFrom, @@ -387,7 +394,7 @@ impl Cors { /// This value is set as the `Access-Control-Max-Age` header. /// /// This defaults to `None` (unset). - pub fn max_age(mut self, max_age: usize) -> Cors { + pub fn max_age(mut self, max_age: usize) -> Self { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.max_age = Some(max_age) } @@ -406,10 +413,10 @@ impl Cors { /// This **CANNOT** be used in conjunction with `allowed_origins` set to /// `All` and `allow_credentials` set to `true`. Depending on the mode /// of usage, this will either result in an `Error:: - /// CredentialsWithWildcardOrigin` error during actix launch or runtime. + /// CredentialsWithWildcardOrigin` error during ntex launch or runtime. /// /// Defaults to `false`. - pub fn send_wildcard(mut self) -> Cors { + pub fn send_wildcard(mut self) -> Self { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.send_wildcard = true } @@ -429,7 +436,7 @@ impl Cors { /// /// Builder panics if credentials are allowed, but the Origin is set to "*". /// This is not allowed by W3C - pub fn supports_credentials(mut self) -> Cors { + pub fn supports_credentials(mut self) -> Self { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.supports_credentials = true } @@ -447,7 +454,7 @@ impl Cors { /// caches that the CORS headers are dynamic, and cannot be cached. /// /// By default `vary` header support is enabled. - pub fn disable_vary_header(mut self) -> Cors { + pub fn disable_vary_header(mut self) -> Self { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.vary_header = false } @@ -460,7 +467,7 @@ impl Cors { /// This is useful application level middleware. /// /// By default *preflight* support is enabled. - pub fn disable_preflight(mut self) -> Cors { + pub fn disable_preflight(mut self) -> Self { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.preflight = false } @@ -468,7 +475,7 @@ impl Cors { } /// Construct cors middleware - pub fn finish(self) -> CorsFactory { + pub fn finish(self) -> CorsFactory { let mut slf = if !self.methods { self.allowed_methods(vec![ Method::GET, @@ -511,13 +518,14 @@ impl Cors { CorsFactory { inner: Rc::new(cors), + _t: PhantomData, } } } fn cors<'a>( parts: &'a mut Option, - err: &Option, + err: &Option, ) -> Option<&'a mut Inner> { if err.is_some() { return None; @@ -529,27 +537,32 @@ fn cors<'a>( /// /// The Cors struct contains the settings for CORS requests to be validated and /// for responses to be generated. -pub struct CorsFactory { +pub struct CorsFactory { inner: Rc, + _t: PhantomData, } -impl Transform for CorsFactory +impl Transform for CorsFactory where - S: Service, Error = Error>, + S: Service, Response = WebResponse>, S::Future: 'static, B: 'static, + Err: ErrorRenderer, + Err::Container: From, + CorsError: WebResponseError, { - type Request = ServiceRequest; - type Response = ServiceResponse; - type Error = Error; + type Request = WebRequest; + type Response = WebResponse; + type Error = S::Error; type InitError = (); - type Transform = CorsMiddleware; + type Transform = CorsMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CorsMiddleware { service, inner: self.inner.clone(), + _t: PhantomData, }) } } @@ -559,9 +572,10 @@ where /// The Cors struct contains the settings for CORS requests to be validated and /// for responses to be generated. #[derive(Clone)] -pub struct CorsMiddleware { +pub struct CorsMiddleware { service: S, inner: Rc, + _t: PhantomData, } struct Inner { @@ -676,25 +690,32 @@ impl Inner { } } -impl Service for CorsMiddleware +impl Service for CorsMiddleware where - S: Service, Error = Error>, + S: Service, Response = WebResponse>, S::Future: 'static, B: 'static, + Err: ErrorRenderer, + Err::Container: From, + CorsError: WebResponseError, { - type Request = ServiceRequest; - type Response = ServiceResponse; - type Error = Error; + type Request = WebRequest; + type Response = WebResponse; + type Error = S::Error; type Future = Either< - Ready>, - LocalBoxFuture<'static, Result>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { self.service.poll_ready(cx) } - fn call(&mut self, req: ServiceRequest) -> Self::Future { + fn poll_shutdown(&self, cx: &mut Context<'_>, is_error: bool) -> Poll<()> { + self.service.poll_shutdown(cx, is_error) + } + + fn call(&self, req: WebRequest) -> Self::Future { if self.inner.preflight && Method::OPTIONS == *req.method() { if let Err(e) = self .inner @@ -702,7 +723,10 @@ where .and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head())) { - return Either::Left(ok(req.error_response(e))); + return Either::Left(ok(req.into_response( + WebResponseError::error_response(&e) + .map_body(|_, body| body.into_body()), + ))); } // allowed headers @@ -760,7 +784,10 @@ where if req.headers().contains_key(&header::ORIGIN) { // Only check requests with a origin header. if let Err(e) = self.inner.validate_origin(req.head()) { - return Either::Left(ok(req.error_response(e))); + return Either::Left(ok(req.into_response( + WebResponseError::error_response(&e) + .map_body(|_, body| body.into_body()), + ))); } } @@ -822,20 +849,20 @@ where #[cfg(test)] mod tests { - use actix_service::{fn_service, Transform}; - use actix_web::test::{self, TestRequest}; + use ntex::service::{fn_service, Transform}; + use ntex::web::test::{self, TestRequest}; use super::*; - #[actix_rt::test] + #[ntex::test] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] async fn cors_validates_illegal_allow_credentials() { let _cors = Cors::new().supports_credentials().send_wildcard().finish(); } - #[actix_rt::test] + #[ntex::test] async fn validate_origin_allows_all_origins() { - let mut cors = Cors::new() + let cors = Cors::new() .finish() .new_transform(test::ok_service()) .await @@ -843,24 +870,24 @@ mod tests { let req = TestRequest::with_header("Origin", "https://www.example.com") .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[actix_rt::test] + #[ntex::test] async fn default() { - let mut cors = Cors::default() + let cors = Cors::default() .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::with_header("Origin", "https://www.example.com") .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[actix_rt::test] + #[ntex::test] async fn test_preflight() { let mut cors = Cors::new() .send_wildcard() @@ -880,7 +907,7 @@ mod tests { assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let req = TestRequest::with_header("Origin", "https://www.example.com") @@ -900,7 +927,7 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"*"[..], resp.headers() @@ -946,11 +973,11 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - // #[actix_rt::test] + // #[ntex::test] // #[should_panic(expected = "MissingOrigin")] // async fn test_validate_missing_origin() { // let cors = Cors::build() @@ -960,7 +987,7 @@ mod tests { // cors.start(&req).unwrap(); // } - #[actix_rt::test] + #[ntex::test] #[should_panic(expected = "OriginNotAllowed")] async fn test_validate_not_allowed_origin() { let cors = Cors::new() @@ -978,9 +1005,9 @@ mod tests { cors.inner.validate_allowed_headers(req.head()).unwrap(); } - #[actix_rt::test] + #[ntex::test] async fn test_validate_origin() { - let mut cors = Cors::new() + let cors = Cors::new() .allowed_origin("https://www.example.com") .finish() .new_transform(test::ok_service()) @@ -991,13 +1018,13 @@ mod tests { .method(Method::GET) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!(resp.status(), StatusCode::OK); } - #[actix_rt::test] + #[ntex::test] async fn test_no_origin_response() { - let mut cors = Cors::new() + let cors = Cors::new() .disable_preflight() .finish() .new_transform(test::ok_service()) @@ -1005,7 +1032,7 @@ mod tests { .unwrap(); let req = TestRequest::default().method(Method::GET).to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert!(resp .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) @@ -1014,7 +1041,7 @@ mod tests { let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://www.example.com"[..], resp.headers() @@ -1024,10 +1051,10 @@ mod tests { ); } - #[actix_rt::test] + #[ntex::test] async fn test_response() { let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() + let cors = Cors::new() .send_wildcard() .disable_preflight() .max_age(3600) @@ -1044,7 +1071,7 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"*"[..], resp.headers() @@ -1074,7 +1101,7 @@ mod tests { } let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() + let cors = Cors::new() .send_wildcard() .disable_preflight() .max_age(3600) @@ -1083,8 +1110,8 @@ mod tests { .expose_headers(exposed_headers.clone()) .allowed_header(header::CONTENT_TYPE) .finish() - .new_transform(fn_service(|req: ServiceRequest| { - ok(req.into_response( + .new_transform(fn_service(|req: WebRequest| { + ok::<_, std::convert::Infallible>(req.into_response( HttpResponse::Ok().header(header::VARY, "Accept").finish(), )) })) @@ -1093,13 +1120,13 @@ mod tests { let req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"Accept, Origin"[..], resp.headers().get(header::VARY).unwrap().as_bytes() ); - let mut cors = Cors::new() + let cors = Cors::new() .disable_vary_header() .allowed_origin("https://www.example.com") .allowed_origin("https://www.google.com") @@ -1112,7 +1139,7 @@ mod tests { .method(Method::OPTIONS) .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; let origins_str = resp .headers() @@ -1124,9 +1151,9 @@ mod tests { assert_eq!("https://www.example.com", origins_str); } - #[actix_rt::test] + #[ntex::test] async fn test_multiple_origins() { - let mut cors = Cors::new() + let cors = Cors::new() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) @@ -1139,7 +1166,7 @@ mod tests { .method(Method::GET) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.com"[..], resp.headers() @@ -1152,7 +1179,7 @@ mod tests { .method(Method::GET) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.org"[..], resp.headers() @@ -1162,9 +1189,9 @@ mod tests { ); } - #[actix_rt::test] + #[ntex::test] async fn test_multiple_origins_preflight() { - let mut cors = Cors::new() + let cors = Cors::new() .allowed_origin("https://example.com") .allowed_origin("https://example.org") .allowed_methods(vec![Method::GET]) @@ -1178,7 +1205,7 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.com"[..], resp.headers() @@ -1192,7 +1219,7 @@ mod tests { .method(Method::OPTIONS) .to_srv_request(); - let resp = test::call_service(&mut cors, req).await; + let resp = test::call_service(&cors, req).await; assert_eq!( &b"https://example.org"[..], resp.headers()