convert cors

This commit is contained in:
Nikolay Kim 2020-04-10 22:43:32 +06:00
parent 864f23c3be
commit f062c249bf
4 changed files with 135 additions and 101 deletions

View file

@ -5,4 +5,7 @@ members = [
"ntex-identity", "ntex-identity",
"ntex-multipart", "ntex-multipart",
"ntex-session", "ntex-session",
] ]
[patch.crates-io]
ntex = { path = "../ntex/ntex/" }

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.1.0] - 2020-04-xx
* Fork to ntex project
## [0.2.0] - 2019-12-20 ## [0.2.0] - 2019-12-20
* Release * Release

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-cors" name = "ntex-cors"
version = "0.2.0" version = "0.1.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Cross-origin resource sharing (CORS) for Actix applications." description = "Cross-origin resource sharing (CORS) for Actix applications."
readme = "README.md" readme = "README.md"

View file

@ -1,5 +1,5 @@
#![allow(clippy::borrow_interior_mutable_const, clippy::type_complexity)] #![allow(clippy::borrow_interior_mutable_const, clippy::type_complexity)]
//! Cross-origin resource sharing (CORS) for Actix applications //! Cross-origin resource sharing (CORS) for ntex applications
//! //!
//! CORS middleware could be used with application and with resource. //! CORS middleware could be used with application and with resource.
//! Cors middleware could be used as parameter for `App::wrap()`, //! Cors middleware could be used as parameter for `App::wrap()`,
@ -7,16 +7,18 @@
//! //!
//! # Example //! # Example
//! //!
//! ```rust //! ```rust,no_run
//! use actix_cors::Cors; //! use ntex_cors::Cors;
//! use actix_web::{http, web, App, HttpRequest, HttpResponse, HttpServer}; //! use ntex::{http, web};
//! use ntex::web::{App, HttpRequest, HttpResponse};
//! //!
//! async fn index(req: HttpRequest) -> &'static str { //! async fn index(req: HttpRequest) -> &'static str {
//! "Hello world" //! "Hello world"
//! } //! }
//! //!
//! fn main() -> std::io::Result<()> { //! #[ntex::main]
//! HttpServer::new(|| App::new() //! async fn main() -> std::io::Result<()> {
//! web::server(|| App::new()
//! .wrap( //! .wrap(
//! Cors::new() // <- Construct CORS middleware builder //! Cors::new() // <- Construct CORS middleware builder
//! .allowed_origin("https://www.rust-lang.org/") //! .allowed_origin("https://www.rust-lang.org/")
@ -28,11 +30,11 @@
//! .service( //! .service(
//! web::resource("/index.html") //! web::resource("/index.html")
//! .route(web::get().to(index)) //! .route(web::get().to(index))
//! .route(web::head().to(|| HttpResponse::MethodNotAllowed())) //! .route(web::head().to(|| async { HttpResponse::MethodNotAllowed() }))
//! )) //! ))
//! .bind("127.0.0.1:8080")?; //! .bind("127.0.0.1:8080")?
//! //! .run()
//! Ok(()) //! .await
//! } //! }
//! ``` //! ```
//! In this example custom *CORS* middleware get registered for "/index.html" //! In this example custom *CORS* middleware get registered for "/index.html"
@ -42,17 +44,18 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::iter::FromIterator; use std::iter::FromIterator;
use std::marker::PhantomData;
use std::rc::Rc; use std::rc::Rc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use actix_service::{Service, Transform};
use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse};
use actix_web::error::{Error, ResponseError, Result};
use actix_web::http::header::{self, HeaderName, HeaderValue};
use actix_web::http::{self, Error as HttpError, Method, StatusCode, Uri};
use actix_web::HttpResponse;
use derive_more::Display; use derive_more::Display;
use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready};
use ntex::http::header::{self, HeaderName, HeaderValue};
use ntex::http::{error::HttpError, Method, RequestHead, StatusCode, Uri};
use ntex::service::{Service, Transform};
use ntex::web::dev::{WebRequest, WebResponse};
use ntex::web::HttpResponse;
use ntex::web::{DefaultError, ErrorRenderer, WebResponseError};
/// A set of errors that can occur during processing CORS /// A set of errors that can occur during processing CORS
#[derive(Debug, Display)] #[derive(Debug, Display)]
@ -93,14 +96,11 @@ pub enum CorsError {
HeadersNotAllowed, HeadersNotAllowed,
} }
impl ResponseError for CorsError { /// DefaultError renderer support
impl WebResponseError<DefaultError> for CorsError {
fn status_code(&self) -> StatusCode { fn status_code(&self) -> StatusCode {
StatusCode::BAD_REQUEST StatusCode::BAD_REQUEST
} }
fn error_response(&self) -> HttpResponse {
HttpResponse::with_body(StatusCode::BAD_REQUEST, format!("{}", self).into())
}
} }
/// An enum signifying that some of type T is allowed, or `All` (everything is /// An enum signifying that some of type T is allowed, or `All` (everything is
@ -157,8 +157,8 @@ impl<T> AllOrSome<T> {
/// # Example /// # Example
/// ///
/// ```rust /// ```rust
/// use actix_cors::Cors; /// use ntex_cors::Cors;
/// use actix_web::http::header; /// use ntex::http::header;
/// ///
/// # fn main() { /// # fn main() {
/// let cors = Cors::new() /// let cors = Cors::new()
@ -170,16 +170,21 @@ impl<T> AllOrSome<T> {
/// # } /// # }
/// ``` /// ```
#[derive(Default)] #[derive(Default)]
pub struct Cors { pub struct Cors<Err: ErrorRenderer> {
cors: Option<Inner>, cors: Option<Inner>,
methods: bool, methods: bool,
error: Option<http::Error>,
expose_hdrs: HashSet<HeaderName>, expose_hdrs: HashSet<HeaderName>,
error: Option<HttpError>,
_t: PhantomData<Err>,
} }
impl Cors { impl<Err: ErrorRenderer> Cors<Err>
where
Err: ErrorRenderer,
CorsError: WebResponseError<Err>,
{
/// Build a new CORS middleware instance /// Build a new CORS middleware instance
pub fn new() -> Cors { pub fn new() -> Self {
Cors { Cors {
cors: Some(Inner { cors: Some(Inner {
origins: AllOrSome::All, origins: AllOrSome::All,
@ -196,11 +201,12 @@ impl Cors {
methods: false, methods: false,
error: None, error: None,
expose_hdrs: HashSet::new(), expose_hdrs: HashSet::new(),
_t: PhantomData,
} }
} }
/// Build a new CORS default middleware /// Build a new CORS default middleware
pub fn default() -> CorsFactory { pub fn default() -> CorsFactory<Err> {
let inner = Inner { let inner = Inner {
origins: AllOrSome::default(), origins: AllOrSome::default(),
origins_str: None, origins_str: None,
@ -226,6 +232,7 @@ impl Cors {
}; };
CorsFactory { CorsFactory {
inner: Rc::new(inner), inner: Rc::new(inner),
_t: PhantomData,
} }
} }
@ -246,7 +253,7 @@ impl Cors {
/// Defaults to `All`. /// Defaults to `All`.
/// ///
/// Builder panics if supplied origin is not valid uri. /// Builder panics if supplied origin is not valid uri.
pub fn allowed_origin(mut self, origin: &str) -> Cors { pub fn allowed_origin(mut self, origin: &str) -> Self {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
match Uri::try_from(origin) { match Uri::try_from(origin) {
Ok(_) => { Ok(_) => {
@ -272,7 +279,7 @@ impl Cors {
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]` /// Defaults to `[GET, HEAD, POST, OPTIONS, PUT, PATCH, DELETE]`
pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors pub fn allowed_methods<U, M>(mut self, methods: U) -> Self
where where
U: IntoIterator<Item = M>, U: IntoIterator<Item = M>,
Method: TryFrom<M>, Method: TryFrom<M>,
@ -296,7 +303,7 @@ impl Cors {
} }
/// Set an allowed header /// Set an allowed header
pub fn allowed_header<H>(mut self, header: H) -> Cors pub fn allowed_header<H>(mut self, header: H) -> Self
where where
HeaderName: TryFrom<H>, HeaderName: TryFrom<H>,
<HeaderName as TryFrom<H>>::Error: Into<HttpError>, <HeaderName as TryFrom<H>>::Error: Into<HttpError>,
@ -328,7 +335,7 @@ impl Cors {
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// Defaults to `All`. /// Defaults to `All`.
pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors pub fn allowed_headers<U, H>(mut self, headers: U) -> Self
where where
U: IntoIterator<Item = H>, U: IntoIterator<Item = H>,
HeaderName: TryFrom<H>, HeaderName: TryFrom<H>,
@ -363,7 +370,7 @@ impl Cors {
/// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model).
/// ///
/// This defaults to an empty set. /// This defaults to an empty set.
pub fn expose_headers<U, H>(mut self, headers: U) -> Cors pub fn expose_headers<U, H>(mut self, headers: U) -> Self
where where
U: IntoIterator<Item = H>, U: IntoIterator<Item = H>,
HeaderName: TryFrom<H>, HeaderName: TryFrom<H>,
@ -387,7 +394,7 @@ impl Cors {
/// This value is set as the `Access-Control-Max-Age` header. /// This value is set as the `Access-Control-Max-Age` header.
/// ///
/// This defaults to `None` (unset). /// This defaults to `None` (unset).
pub fn max_age(mut self, max_age: usize) -> Cors { pub fn max_age(mut self, max_age: usize) -> Self {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.max_age = Some(max_age) cors.max_age = Some(max_age)
} }
@ -406,10 +413,10 @@ impl Cors {
/// This **CANNOT** be used in conjunction with `allowed_origins` set to /// This **CANNOT** be used in conjunction with `allowed_origins` set to
/// `All` and `allow_credentials` set to `true`. Depending on the mode /// `All` and `allow_credentials` set to `true`. Depending on the mode
/// of usage, this will either result in an `Error:: /// of usage, this will either result in an `Error::
/// CredentialsWithWildcardOrigin` error during actix launch or runtime. /// CredentialsWithWildcardOrigin` error during ntex launch or runtime.
/// ///
/// Defaults to `false`. /// Defaults to `false`.
pub fn send_wildcard(mut self) -> Cors { pub fn send_wildcard(mut self) -> Self {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.send_wildcard = true cors.send_wildcard = true
} }
@ -429,7 +436,7 @@ impl Cors {
/// ///
/// Builder panics if credentials are allowed, but the Origin is set to "*". /// Builder panics if credentials are allowed, but the Origin is set to "*".
/// This is not allowed by W3C /// This is not allowed by W3C
pub fn supports_credentials(mut self) -> Cors { pub fn supports_credentials(mut self) -> Self {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.supports_credentials = true cors.supports_credentials = true
} }
@ -447,7 +454,7 @@ impl Cors {
/// caches that the CORS headers are dynamic, and cannot be cached. /// caches that the CORS headers are dynamic, and cannot be cached.
/// ///
/// By default `vary` header support is enabled. /// By default `vary` header support is enabled.
pub fn disable_vary_header(mut self) -> Cors { pub fn disable_vary_header(mut self) -> Self {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.vary_header = false cors.vary_header = false
} }
@ -460,7 +467,7 @@ impl Cors {
/// This is useful application level middleware. /// This is useful application level middleware.
/// ///
/// By default *preflight* support is enabled. /// By default *preflight* support is enabled.
pub fn disable_preflight(mut self) -> Cors { pub fn disable_preflight(mut self) -> Self {
if let Some(cors) = cors(&mut self.cors, &self.error) { if let Some(cors) = cors(&mut self.cors, &self.error) {
cors.preflight = false cors.preflight = false
} }
@ -468,7 +475,7 @@ impl Cors {
} }
/// Construct cors middleware /// Construct cors middleware
pub fn finish(self) -> CorsFactory { pub fn finish(self) -> CorsFactory<Err> {
let mut slf = if !self.methods { let mut slf = if !self.methods {
self.allowed_methods(vec![ self.allowed_methods(vec![
Method::GET, Method::GET,
@ -511,13 +518,14 @@ impl Cors {
CorsFactory { CorsFactory {
inner: Rc::new(cors), inner: Rc::new(cors),
_t: PhantomData,
} }
} }
} }
fn cors<'a>( fn cors<'a>(
parts: &'a mut Option<Inner>, parts: &'a mut Option<Inner>,
err: &Option<http::Error>, err: &Option<HttpError>,
) -> Option<&'a mut Inner> { ) -> Option<&'a mut Inner> {
if err.is_some() { if err.is_some() {
return None; return None;
@ -529,27 +537,32 @@ fn cors<'a>(
/// ///
/// The Cors struct contains the settings for CORS requests to be validated and /// The Cors struct contains the settings for CORS requests to be validated and
/// for responses to be generated. /// for responses to be generated.
pub struct CorsFactory { pub struct CorsFactory<Err> {
inner: Rc<Inner>, inner: Rc<Inner>,
_t: PhantomData<Err>,
} }
impl<S, B> Transform<S> for CorsFactory impl<S, B, Err> Transform<S> for CorsFactory<Err>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<Request = WebRequest<Err>, Response = WebResponse<B>>,
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
Err: ErrorRenderer,
Err::Container: From<S::Error>,
CorsError: WebResponseError<Err>,
{ {
type Request = ServiceRequest; type Request = WebRequest<Err>;
type Response = ServiceResponse<B>; type Response = WebResponse<B>;
type Error = Error; type Error = S::Error;
type InitError = (); type InitError = ();
type Transform = CorsMiddleware<S>; type Transform = CorsMiddleware<S, Err>;
type Future = Ready<Result<Self::Transform, Self::InitError>>; type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future { fn new_transform(&self, service: S) -> Self::Future {
ok(CorsMiddleware { ok(CorsMiddleware {
service, service,
inner: self.inner.clone(), inner: self.inner.clone(),
_t: PhantomData,
}) })
} }
} }
@ -559,9 +572,10 @@ where
/// The Cors struct contains the settings for CORS requests to be validated and /// The Cors struct contains the settings for CORS requests to be validated and
/// for responses to be generated. /// for responses to be generated.
#[derive(Clone)] #[derive(Clone)]
pub struct CorsMiddleware<S> { pub struct CorsMiddleware<S, Err> {
service: S, service: S,
inner: Rc<Inner>, inner: Rc<Inner>,
_t: PhantomData<Err>,
} }
struct Inner { struct Inner {
@ -676,25 +690,32 @@ impl Inner {
} }
} }
impl<S, B> Service for CorsMiddleware<S> impl<S, B, Err> Service for CorsMiddleware<S, Err>
where where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<Request = WebRequest<Err>, Response = WebResponse<B>>,
S::Future: 'static, S::Future: 'static,
B: 'static, B: 'static,
Err: ErrorRenderer,
Err::Container: From<S::Error>,
CorsError: WebResponseError<Err>,
{ {
type Request = ServiceRequest; type Request = WebRequest<Err>;
type Response = ServiceResponse<B>; type Response = WebResponse<B>;
type Error = Error; type Error = S::Error;
type Future = Either< type Future = Either<
Ready<Result<Self::Response, Error>>, Ready<Result<Self::Response, S::Error>>,
LocalBoxFuture<'static, Result<Self::Response, Error>>, LocalBoxFuture<'static, Result<Self::Response, S::Error>>,
>; >;
fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx) self.service.poll_ready(cx)
} }
fn call(&mut self, req: ServiceRequest) -> Self::Future { fn poll_shutdown(&self, cx: &mut Context<'_>, is_error: bool) -> Poll<()> {
self.service.poll_shutdown(cx, is_error)
}
fn call(&self, req: WebRequest<Err>) -> Self::Future {
if self.inner.preflight && Method::OPTIONS == *req.method() { if self.inner.preflight && Method::OPTIONS == *req.method() {
if let Err(e) = self if let Err(e) = self
.inner .inner
@ -702,7 +723,10 @@ where
.and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_method(req.head()))
.and_then(|_| self.inner.validate_allowed_headers(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head()))
{ {
return Either::Left(ok(req.error_response(e))); return Either::Left(ok(req.into_response(
WebResponseError::error_response(&e)
.map_body(|_, body| body.into_body()),
)));
} }
// allowed headers // allowed headers
@ -760,7 +784,10 @@ where
if req.headers().contains_key(&header::ORIGIN) { if req.headers().contains_key(&header::ORIGIN) {
// Only check requests with a origin header. // Only check requests with a origin header.
if let Err(e) = self.inner.validate_origin(req.head()) { if let Err(e) = self.inner.validate_origin(req.head()) {
return Either::Left(ok(req.error_response(e))); return Either::Left(ok(req.into_response(
WebResponseError::error_response(&e)
.map_body(|_, body| body.into_body()),
)));
} }
} }
@ -822,20 +849,20 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use actix_service::{fn_service, Transform}; use ntex::service::{fn_service, Transform};
use actix_web::test::{self, TestRequest}; use ntex::web::test::{self, TestRequest};
use super::*; use super::*;
#[actix_rt::test] #[ntex::test]
#[should_panic(expected = "Credentials are allowed, but the Origin is set to")] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")]
async fn cors_validates_illegal_allow_credentials() { async fn cors_validates_illegal_allow_credentials() {
let _cors = Cors::new().supports_credentials().send_wildcard().finish(); let _cors = Cors::new().supports_credentials().send_wildcard().finish();
} }
#[actix_rt::test] #[ntex::test]
async fn validate_origin_allows_all_origins() { async fn validate_origin_allows_all_origins() {
let mut cors = Cors::new() let cors = Cors::new()
.finish() .finish()
.new_transform(test::ok_service()) .new_transform(test::ok_service())
.await .await
@ -843,24 +870,24 @@ mod tests {
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[ntex::test]
async fn default() { async fn default() {
let mut cors = Cors::default() let cors = Cors::default()
.new_transform(test::ok_service()) .new_transform(test::ok_service())
.await .await
.unwrap(); .unwrap();
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[ntex::test]
async fn test_preflight() { async fn test_preflight() {
let mut cors = Cors::new() let mut cors = Cors::new()
.send_wildcard() .send_wildcard()
@ -880,7 +907,7 @@ mod tests {
assert!(cors.inner.validate_allowed_method(req.head()).is_err()); assert!(cors.inner.validate_allowed_method(req.head()).is_err());
assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
@ -900,7 +927,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
@ -946,11 +973,11 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
// #[actix_rt::test] // #[ntex::test]
// #[should_panic(expected = "MissingOrigin")] // #[should_panic(expected = "MissingOrigin")]
// async fn test_validate_missing_origin() { // async fn test_validate_missing_origin() {
// let cors = Cors::build() // let cors = Cors::build()
@ -960,7 +987,7 @@ mod tests {
// cors.start(&req).unwrap(); // cors.start(&req).unwrap();
// } // }
#[actix_rt::test] #[ntex::test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
async fn test_validate_not_allowed_origin() { async fn test_validate_not_allowed_origin() {
let cors = Cors::new() let cors = Cors::new()
@ -978,9 +1005,9 @@ mod tests {
cors.inner.validate_allowed_headers(req.head()).unwrap(); cors.inner.validate_allowed_headers(req.head()).unwrap();
} }
#[actix_rt::test] #[ntex::test]
async fn test_validate_origin() { async fn test_validate_origin() {
let mut cors = Cors::new() let cors = Cors::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish() .finish()
.new_transform(test::ok_service()) .new_transform(test::ok_service())
@ -991,13 +1018,13 @@ mod tests {
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[actix_rt::test] #[ntex::test]
async fn test_no_origin_response() { async fn test_no_origin_response() {
let mut cors = Cors::new() let cors = Cors::new()
.disable_preflight() .disable_preflight()
.finish() .finish()
.new_transform(test::ok_service()) .new_transform(test::ok_service())
@ -1005,7 +1032,7 @@ mod tests {
.unwrap(); .unwrap();
let req = TestRequest::default().method(Method::GET).to_srv_request(); let req = TestRequest::default().method(Method::GET).to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert!(resp assert!(resp
.headers() .headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
@ -1014,7 +1041,7 @@ mod tests {
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"https://www.example.com"[..], &b"https://www.example.com"[..],
resp.headers() resp.headers()
@ -1024,10 +1051,10 @@ mod tests {
); );
} }
#[actix_rt::test] #[ntex::test]
async fn test_response() { async fn test_response() {
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new() let cors = Cors::new()
.send_wildcard() .send_wildcard()
.disable_preflight() .disable_preflight()
.max_age(3600) .max_age(3600)
@ -1044,7 +1071,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"*"[..], &b"*"[..],
resp.headers() resp.headers()
@ -1074,7 +1101,7 @@ mod tests {
} }
let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT];
let mut cors = Cors::new() let cors = Cors::new()
.send_wildcard() .send_wildcard()
.disable_preflight() .disable_preflight()
.max_age(3600) .max_age(3600)
@ -1083,8 +1110,8 @@ mod tests {
.expose_headers(exposed_headers.clone()) .expose_headers(exposed_headers.clone())
.allowed_header(header::CONTENT_TYPE) .allowed_header(header::CONTENT_TYPE)
.finish() .finish()
.new_transform(fn_service(|req: ServiceRequest| { .new_transform(fn_service(|req: WebRequest<DefaultError>| {
ok(req.into_response( ok::<_, std::convert::Infallible>(req.into_response(
HttpResponse::Ok().header(header::VARY, "Accept").finish(), HttpResponse::Ok().header(header::VARY, "Accept").finish(),
)) ))
})) }))
@ -1093,13 +1120,13 @@ mod tests {
let req = TestRequest::with_header("Origin", "https://www.example.com") let req = TestRequest::with_header("Origin", "https://www.example.com")
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"Accept, Origin"[..], &b"Accept, Origin"[..],
resp.headers().get(header::VARY).unwrap().as_bytes() resp.headers().get(header::VARY).unwrap().as_bytes()
); );
let mut cors = Cors::new() let cors = Cors::new()
.disable_vary_header() .disable_vary_header()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allowed_origin("https://www.google.com") .allowed_origin("https://www.google.com")
@ -1112,7 +1139,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
let origins_str = resp let origins_str = resp
.headers() .headers()
@ -1124,9 +1151,9 @@ mod tests {
assert_eq!("https://www.example.com", origins_str); assert_eq!("https://www.example.com", origins_str);
} }
#[actix_rt::test] #[ntex::test]
async fn test_multiple_origins() { async fn test_multiple_origins() {
let mut cors = Cors::new() let cors = Cors::new()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
@ -1139,7 +1166,7 @@ mod tests {
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
@ -1152,7 +1179,7 @@ mod tests {
.method(Method::GET) .method(Method::GET)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()
@ -1162,9 +1189,9 @@ mod tests {
); );
} }
#[actix_rt::test] #[ntex::test]
async fn test_multiple_origins_preflight() { async fn test_multiple_origins_preflight() {
let mut cors = Cors::new() let cors = Cors::new()
.allowed_origin("https://example.com") .allowed_origin("https://example.com")
.allowed_origin("https://example.org") .allowed_origin("https://example.org")
.allowed_methods(vec![Method::GET]) .allowed_methods(vec![Method::GET])
@ -1178,7 +1205,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.com"[..], &b"https://example.com"[..],
resp.headers() resp.headers()
@ -1192,7 +1219,7 @@ mod tests {
.method(Method::OPTIONS) .method(Method::OPTIONS)
.to_srv_request(); .to_srv_request();
let resp = test::call_service(&mut cors, req).await; let resp = test::call_service(&cors, req).await;
assert_eq!( assert_eq!(
&b"https://example.org"[..], &b"https://example.org"[..],
resp.headers() resp.headers()