Change state management behaviour

This commit is contained in:
Nikolay Kim 2022-11-25 16:45:42 +01:00
parent d947f7f08c
commit c8530d65a5
17 changed files with 348 additions and 443 deletions

View file

@ -8,7 +8,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
version: version:
- 1.59.0 # MSRV - 1.60.0 # MSRV
- stable - stable
- nightly - nightly
@ -43,7 +43,7 @@ jobs:
key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-index-trimmed-${{ hashFiles('**/Cargo.lock') }} key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-index-trimmed-${{ hashFiles('**/Cargo.lock') }}
- name: Cache cargo tarpaulin - name: Cache cargo tarpaulin
if: matrix.version == '1.59.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') if: matrix.version == '1.60.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
uses: actions/cache@v1 uses: actions/cache@v1
with: with:
path: ~/.cargo/bin path: ~/.cargo/bin
@ -64,26 +64,26 @@ jobs:
cargo test --no-default-features --no-fail-fast --features="async-std,cookie,url,compress,openssl,rustls" --lib -- --test-threads 1 cargo test --no-default-features --no-fail-fast --features="async-std,cookie,url,compress,openssl,rustls" --lib -- --test-threads 1
- name: Install tarpaulin - name: Install tarpaulin
if: matrix.version == '1.59.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') if: matrix.version == '1.60.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true continue-on-error: true
run: | run: |
cargo install cargo-tarpaulin cargo install cargo-tarpaulin
- name: Generate coverage report - name: Generate coverage report
if: matrix.version == '1.59.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') if: matrix.version == '1.60.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true continue-on-error: true
run: | run: |
cargo tarpaulin --out Xml --all --all-features cargo tarpaulin --out Xml --all --all-features
- name: Generate coverage report (glommio) - name: Generate coverage report (glommio)
if: matrix.version == '1.59.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') if: matrix.version == '1.60.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true continue-on-error: true
run: | run: |
cd ntex cd ntex
sudo -E env PATH="$PATH" bash -c "ulimit -l 512 && ulimit -a && cargo tarpaulin --out Xml --no-default-features --features=\"glommio,cookie,url,compress,openssl,rustls\"" sudo -E env PATH="$PATH" bash -c "ulimit -l 512 && ulimit -a && cargo tarpaulin --out Xml --no-default-features --features=\"glommio,cookie,url,compress,openssl,rustls\""
- name: Upload to Codecov - name: Upload to Codecov
if: matrix.version == '1.59.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') if: matrix.version == '1.60.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true continue-on-error: true
uses: codecov/codecov-action@v2 uses: codecov/codecov-action@v2
with: with:

View file

@ -2,6 +2,9 @@
## [0.5.30] - 2022-11-25 ## [0.5.30] - 2022-11-25
* Change `App::state()` behaviour
* Remove `App::app_state()` method
## [0.5.29] - 2022-11-03 ## [0.5.29] - 2022-11-03

View file

@ -54,7 +54,7 @@ ntex-http = "0.1.7"
ntex-router = "0.5.1" ntex-router = "0.5.1"
ntex-service = "0.3.2" ntex-service = "0.3.2"
ntex-macros = "0.1.3" ntex-macros = "0.1.3"
ntex-util = "0.1.17" ntex-util = "0.1.18"
ntex-bytes = "0.1.16" ntex-bytes = "0.1.16"
ntex-h2 = "0.1.5" ntex-h2 = "0.1.5"
ntex-rt = "0.4.6" ntex-rt = "0.4.6"

View file

@ -16,13 +16,12 @@ use super::resource::Resource;
use super::response::WebResponse; use super::response::WebResponse;
use super::route::Route; use super::route::Route;
use super::service::{AppServiceFactory, ServiceFactoryWrapper, WebServiceFactory}; use super::service::{AppServiceFactory, ServiceFactoryWrapper, WebServiceFactory};
use super::types::state::{State, StateFactory};
use super::{DefaultError, ErrorRenderer}; use super::{DefaultError, ErrorRenderer};
type HttpNewService<Err: ErrorRenderer> = type HttpNewService<Err: ErrorRenderer> =
BoxServiceFactory<(), WebRequest<Err>, WebResponse, Err::Container, ()>; BoxServiceFactory<(), WebRequest<Err>, WebResponse, Err::Container, ()>;
type FnStateFactory = type FnStateFactory =
Box<dyn Fn() -> Pin<Box<dyn Future<Output = Result<Box<dyn StateFactory>, ()>>>>>; Box<dyn Fn(Extensions) -> Pin<Box<dyn Future<Output = Result<Extensions, ()>>>>>;
/// Application builder - structure that follows the builder pattern /// Application builder - structure that follows the builder pattern
/// for building application instances. /// for building application instances.
@ -31,10 +30,9 @@ pub struct App<M, F, Err: ErrorRenderer = DefaultError> {
filter: PipelineFactory<F, WebRequest<Err>>, filter: PipelineFactory<F, WebRequest<Err>>,
services: Vec<Box<dyn AppServiceFactory<Err>>>, services: Vec<Box<dyn AppServiceFactory<Err>>>,
default: Option<Rc<HttpNewService<Err>>>, default: Option<Rc<HttpNewService<Err>>>,
state: Vec<Box<dyn StateFactory>>,
state_factories: Vec<FnStateFactory>,
external: Vec<ResourceDef>, external: Vec<ResourceDef>,
extensions: Extensions, extensions: Extensions,
state_factories: Vec<FnStateFactory>,
error_renderer: Err, error_renderer: Err,
case_insensitive: bool, case_insensitive: bool,
} }
@ -45,7 +43,6 @@ impl App<Identity, Filter<DefaultError>, DefaultError> {
App { App {
middleware: Identity, middleware: Identity,
filter: pipeline_factory(Filter::new()), filter: pipeline_factory(Filter::new()),
state: Vec::new(),
state_factories: Vec::new(), state_factories: Vec::new(),
services: Vec::new(), services: Vec::new(),
default: None, default: None,
@ -63,7 +60,6 @@ impl<Err: ErrorRenderer> App<Identity, Filter<Err>, Err> {
App { App {
middleware: Identity, middleware: Identity,
filter: pipeline_factory(Filter::new()), filter: pipeline_factory(Filter::new()),
state: Vec::new(),
state_factories: Vec::new(), state_factories: Vec::new(),
services: Vec::new(), services: Vec::new(),
default: None, default: None,
@ -86,16 +82,19 @@ where
T::Future: 'static, T::Future: 'static,
Err: ErrorRenderer, Err: ErrorRenderer,
{ {
/// Set application state. Application state could be accessed /// Set application level arbitrary state item.
/// by using `State<T>` extractor where `T` is state type. ///
/// Application state stored with `App::app_state()` method is available
/// via `HttpRequest::app_state()` method at runtime.
///
/// This method could be used for storing `State<T>` as well, in that case
/// state could be accessed by using `State<T>` extractor.
/// ///
/// **Note**: http server accepts an application factory rather than /// **Note**: http server accepts an application factory rather than
/// an application instance. Http server constructs an application /// an application instance. Http server constructs an application
/// instance for each thread, thus application state must be constructed /// instance for each thread, thus application state must be constructed
/// multiple times. If you want to share state between different /// multiple times. If you want to share state between different
/// threads, a shared object should be used, e.g. `Arc`. Internally `State` type /// threads, a shared object should be used, e.g. `Arc`.
/// uses `Arc` so statw could be created outside of app factory and clones could
/// be stored via `App::app_state()` method.
/// ///
/// ```rust /// ```rust
/// use std::cell::Cell; /// use std::cell::Cell;
@ -117,7 +116,7 @@ where
/// ); /// );
/// ``` /// ```
pub fn state<U: 'static>(mut self, state: U) -> Self { pub fn state<U: 'static>(mut self, state: U) -> Self {
self.state.push(Box::new(State::new(state))); self.extensions.insert(state);
self self
} }
@ -131,7 +130,7 @@ where
D: 'static, D: 'static,
E: fmt::Debug, E: fmt::Debug,
{ {
self.state_factories.push(Box::new(move || { self.state_factories.push(Box::new(move |mut ext| {
let fut = state(); let fut = state();
Box::pin(async move { Box::pin(async move {
match fut.await { match fut.await {
@ -140,8 +139,8 @@ where
Err(()) Err(())
} }
Ok(st) => { Ok(st) => {
let st: Box<dyn StateFactory> = Box::new(State::new(st)); ext.insert(st);
Ok(st) Ok(ext)
} }
} }
}) })
@ -149,18 +148,6 @@ where
self self
} }
/// Set application level arbitrary state item.
///
/// Application state stored with `App::app_state()` method is available
/// via `HttpRequest::app_state()` method at runtime.
///
/// This method could be used for storing `State<T>` as well, in that case
/// state could be accessed by using `State<T>` extractor.
pub fn app_state<U: 'static>(mut self, ext: U) -> Self {
self.extensions.insert(ext);
self
}
/// Run external configuration as part of the application building /// Run external configuration as part of the application building
/// process /// process
/// ///
@ -192,9 +179,9 @@ where
{ {
let mut cfg = ServiceConfig::new(); let mut cfg = ServiceConfig::new();
f(&mut cfg); f(&mut cfg);
self.state.extend(cfg.state);
self.services.extend(cfg.services); self.services.extend(cfg.services);
self.external.extend(cfg.external); self.external.extend(cfg.external);
self.extensions.extend(cfg.state);
self self
} }
@ -375,7 +362,6 @@ where
App { App {
filter: self.filter.and_then(filter.into_factory()), filter: self.filter.and_then(filter.into_factory()),
middleware: self.middleware, middleware: self.middleware,
state: self.state,
state_factories: self.state_factories, state_factories: self.state_factories,
services: self.services, services: self.services,
default: self.default, default: self.default,
@ -416,7 +402,6 @@ where
App { App {
middleware: Stack::new(self.middleware, mw), middleware: Stack::new(self.middleware, mw),
filter: self.filter, filter: self.filter,
state: self.state,
state_factories: self.state_factories, state_factories: self.state_factories,
services: self.services, services: self.services,
default: self.default, default: self.default,
@ -508,7 +493,6 @@ where
let app = AppFactory { let app = AppFactory {
filter: self.filter, filter: self.filter,
middleware: Rc::new(self.middleware), middleware: Rc::new(self.middleware),
state: Rc::new(self.state),
state_factories: Rc::new(self.state_factories), state_factories: Rc::new(self.state_factories),
services: Rc::new(RefCell::new(self.services)), services: Rc::new(RefCell::new(self.services)),
external: RefCell::new(self.external), external: RefCell::new(self.external),
@ -538,7 +522,6 @@ where
AppFactory { AppFactory {
filter: self.filter, filter: self.filter,
middleware: Rc::new(self.middleware), middleware: Rc::new(self.middleware),
state: Rc::new(self.state),
state_factories: Rc::new(self.state_factories), state_factories: Rc::new(self.state_factories),
services: Rc::new(RefCell::new(self.services)), services: Rc::new(RefCell::new(self.services)),
external: RefCell::new(self.external), external: RefCell::new(self.external),
@ -566,7 +549,6 @@ where
AppFactory { AppFactory {
filter: self.filter, filter: self.filter,
middleware: Rc::new(self.middleware), middleware: Rc::new(self.middleware),
state: Rc::new(self.state),
state_factories: Rc::new(self.state_factories), state_factories: Rc::new(self.state_factories),
services: Rc::new(RefCell::new(self.services)), services: Rc::new(RefCell::new(self.services)),
external: RefCell::new(self.external), external: RefCell::new(self.external),
@ -729,12 +711,12 @@ mod tests {
#[crate::rt_test] #[crate::rt_test]
async fn test_extension() { async fn test_extension() {
let srv = init_service(App::new().app_state(10usize).service( let srv = init_service(App::new().state(10usize).service(web::resource("/").to(
web::resource("/").to(|req: HttpRequest| async move { |req: HttpRequest| async move {
assert_eq!(*req.app_state::<usize>().unwrap(), 10); assert_eq!(*req.app_state::<usize>().unwrap(), 10);
HttpResponse::Ok() HttpResponse::Ok()
}), },
)) )))
.await; .await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap(); let resp = srv.call(req).await.unwrap();

View file

@ -14,8 +14,7 @@ use super::httprequest::{HttpRequest, HttpRequestPool};
use super::request::WebRequest; use super::request::WebRequest;
use super::response::WebResponse; use super::response::WebResponse;
use super::rmap::ResourceMap; use super::rmap::ResourceMap;
use super::service::{AppServiceFactory, WebServiceConfig}; use super::service::{AppServiceFactory, AppState, WebServiceConfig};
use super::types::state::StateFactory;
type Guards = Vec<Box<dyn Guard>>; type Guards = Vec<Box<dyn Guard>>;
type HttpService<Err: ErrorRenderer> = type HttpService<Err: ErrorRenderer> =
@ -25,7 +24,7 @@ type HttpNewService<Err: ErrorRenderer> =
type BoxResponse<Err: ErrorRenderer> = type BoxResponse<Err: ErrorRenderer> =
Pin<Box<dyn Future<Output = Result<WebResponse, Err::Container>>>>; Pin<Box<dyn Future<Output = Result<WebResponse, Err::Container>>>>;
type FnStateFactory = type FnStateFactory =
Box<dyn Fn() -> Pin<Box<dyn Future<Output = Result<Box<dyn StateFactory>, ()>>>>>; Box<dyn Fn(Extensions) -> Pin<Box<dyn Future<Output = Result<Extensions, ()>>>>>;
/// Service factory to convert `Request` to a `WebRequest<S>`. /// Service factory to convert `Request` to a `WebRequest<S>`.
/// It also executes state factories. /// It also executes state factories.
@ -43,7 +42,6 @@ where
pub(super) middleware: Rc<T>, pub(super) middleware: Rc<T>,
pub(super) filter: PipelineFactory<F, WebRequest<Err>>, pub(super) filter: PipelineFactory<F, WebRequest<Err>>,
pub(super) extensions: RefCell<Option<Extensions>>, pub(super) extensions: RefCell<Option<Extensions>>,
pub(super) state: Rc<Vec<Box<dyn StateFactory>>>,
pub(super) state_factories: Rc<Vec<FnStateFactory>>, pub(super) state_factories: Rc<Vec<FnStateFactory>>,
pub(super) services: Rc<RefCell<Vec<Box<dyn AppServiceFactory<Err>>>>>, pub(super) services: Rc<RefCell<Vec<Box<dyn AppServiceFactory<Err>>>>>,
pub(super) default: Option<Rc<HttpNewService<Err>>>, pub(super) default: Option<Rc<HttpNewService<Err>>>,
@ -95,6 +93,8 @@ where
type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Self::InitError>>>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Service, Self::InitError>>>>;
fn new_service(&self, config: AppConfig) -> Self::Future { fn new_service(&self, config: AppConfig) -> Self::Future {
let services = std::mem::take(&mut *self.services.borrow_mut());
// update resource default service // update resource default service
let default = self.default.clone().unwrap_or_else(|| { let default = self.default.clone().unwrap_or_else(|| {
Rc::new(boxed::factory(fn_service( Rc::new(boxed::factory(fn_service(
@ -104,42 +104,7 @@ where
))) )))
}); });
// App config
let mut config = WebServiceConfig::new(config, default.clone(), self.state.clone());
// register services
std::mem::take(&mut *self.services.borrow_mut())
.into_iter()
.for_each(|mut srv| srv.register(&mut config));
let (config, services) = config.into_services();
// resource map
let mut rmap = ResourceMap::new(ResourceDef::new(""));
for mut rdef in std::mem::take(&mut *self.external.borrow_mut()) {
rmap.add(&mut rdef, None);
}
// complete pipeline creation
let services: Vec<_> = services
.into_iter()
.map(|(mut rdef, srv, guards, nested)| {
rmap.add(&mut rdef, nested);
(rdef, srv, RefCell::new(guards))
})
.collect();
let default_fut = default.new_service(());
let mut router = Router::build();
if self.case_insensitive {
router.case_insensitive();
}
// complete ResourceMap tree creation
let rmap = Rc::new(rmap);
rmap.finish(rmap.clone());
let filter_fut = self.filter.new_service(()); let filter_fut = self.filter.new_service(());
let state = self.state.clone();
let state_factories = self.state_factories.clone(); let state_factories = self.state_factories.clone();
let mut extensions = self let mut extensions = self
.extensions .extensions
@ -147,8 +112,48 @@ where
.take() .take()
.unwrap_or_else(Extensions::new); .unwrap_or_else(Extensions::new);
let middleware = self.middleware.clone(); let middleware = self.middleware.clone();
let external = std::mem::take(&mut *self.external.borrow_mut());
let mut router = Router::build();
if self.case_insensitive {
router.case_insensitive();
}
Box::pin(async move { Box::pin(async move {
// app state factories
for fut in state_factories.iter() {
extensions = fut(extensions).await?;
}
let state = AppState::new(extensions, None, config);
// App config
let mut config = WebServiceConfig::new(state.clone(), default.clone());
// register services
services
.into_iter()
.for_each(|mut srv| srv.register(&mut config));
let services = config.into_services();
// resource map
let mut rmap = ResourceMap::new(ResourceDef::new(""));
for mut rdef in external {
rmap.add(&mut rdef, None);
}
// complete pipeline creation
let services: Vec<_> = services
.into_iter()
.map(|(mut rdef, srv, guards, nested)| {
rmap.add(&mut rdef, nested);
(rdef, srv, RefCell::new(guards))
})
.collect();
// complete ResourceMap tree creation
let rmap = Rc::new(rmap);
rmap.finish(rmap.clone());
// create http services // create http services
for (path, factory, guards) in &mut services.iter() { for (path, factory, guards) in &mut services.iter() {
let service = factory.new_service(()).await?; let service = factory.new_service(()).await?;
@ -157,7 +162,7 @@ where
let routing = AppRouting { let routing = AppRouting {
router: router.finish(), router: router.finish(),
default: Some(default_fut.await?), default: Some(default.new_service(()).await?),
}; };
// main service // main service
@ -166,23 +171,10 @@ where
routing: Rc::new(routing), routing: Rc::new(routing),
}; };
// create app state container
for f in state.iter() {
f.create(&mut extensions);
}
// async state factories
for fut in state_factories.iter() {
if let Ok(f) = fut().await {
f.create(&mut extensions);
}
}
Ok(AppFactoryService { Ok(AppFactoryService {
rmap, rmap,
config, state,
service: middleware.new_transform(service), service: middleware.new_transform(service),
state: Rc::new(extensions),
pool: HttpRequestPool::create(), pool: HttpRequestPool::create(),
_t: PhantomData, _t: PhantomData,
}) })
@ -198,8 +190,7 @@ where
{ {
service: T, service: T,
rmap: Rc<ResourceMap>, rmap: Rc<ResourceMap>,
config: AppConfig, state: AppState,
state: Rc<Extensions>,
pool: &'static HttpRequestPool, pool: &'static HttpRequestPool,
_t: PhantomData<Err>, _t: PhantomData<Err>,
} }
@ -239,7 +230,6 @@ where
head, head,
payload, payload,
self.rmap.clone(), self.rmap.clone(),
self.config.clone(),
self.state.clone(), self.state.clone(),
self.pool, self.pool,
) )

View file

@ -1,17 +1,17 @@
use std::{net::SocketAddr, rc::Rc}; use std::{net::SocketAddr, rc::Rc};
use crate::router::ResourceDef; use crate::{router::ResourceDef, util::Extensions};
use super::resource::Resource; use super::resource::Resource;
use super::route::Route; use super::route::Route;
use super::service::{AppServiceFactory, ServiceFactoryWrapper, WebServiceFactory}; use super::service::{AppServiceFactory, ServiceFactoryWrapper, WebServiceFactory};
use super::types::state::{State, StateFactory};
use super::{DefaultError, ErrorRenderer}; use super::{DefaultError, ErrorRenderer};
/// Application configuration /// Application configuration
#[derive(Clone)] #[derive(Debug, Clone)]
pub struct AppConfig(Rc<AppConfigInner>); pub struct AppConfig(Rc<AppConfigInner>);
#[derive(Debug)]
struct AppConfigInner { struct AppConfigInner {
secure: bool, secure: bool,
host: String, host: String,
@ -61,7 +61,7 @@ impl Default for AppConfig {
/// modularization of big application configuration. /// modularization of big application configuration.
pub struct ServiceConfig<Err = DefaultError> { pub struct ServiceConfig<Err = DefaultError> {
pub(super) services: Vec<Box<dyn AppServiceFactory<Err>>>, pub(super) services: Vec<Box<dyn AppServiceFactory<Err>>>,
pub(super) state: Vec<Box<dyn StateFactory>>, pub(super) state: Extensions,
pub(super) external: Vec<ResourceDef>, pub(super) external: Vec<ResourceDef>,
} }
@ -69,17 +69,16 @@ impl<Err: ErrorRenderer> ServiceConfig<Err> {
pub(crate) fn new() -> Self { pub(crate) fn new() -> Self {
Self { Self {
services: Vec::new(), services: Vec::new(),
state: Vec::new(), state: Extensions::new(),
external: Vec::new(), external: Vec::new(),
} }
} }
/// Set application state. Application state could be accessed /// Set application state.
/// by using `State<T>` extractor where `T` is state type.
/// ///
/// This is same as `App::state()` method. /// This is same as `App::state()` method.
pub fn state<S: 'static>(&mut self, st: S) -> &mut Self { pub fn state<S: 'static>(&mut self, st: S) -> &mut Self {
self.state.push(Box::new(State::new(st))); self.state.insert(st);
self self
} }

View file

@ -12,6 +12,7 @@ use super::error::ErrorRenderer;
use super::extract::FromRequest; use super::extract::FromRequest;
use super::info::ConnectionInfo; use super::info::ConnectionInfo;
use super::rmap::ResourceMap; use super::rmap::ResourceMap;
use super::service::AppState;
#[derive(Clone)] #[derive(Clone)]
/// An HTTP Request /// An HTTP Request
@ -21,9 +22,8 @@ pub(crate) struct HttpRequestInner {
pub(crate) head: Message<RequestHead>, pub(crate) head: Message<RequestHead>,
pub(crate) path: Path<Uri>, pub(crate) path: Path<Uri>,
pub(crate) payload: Payload, pub(crate) payload: Payload,
pub(crate) app_state: Rc<Extensions>, pub(crate) app_state: AppState,
rmap: Rc<ResourceMap>, rmap: Rc<ResourceMap>,
config: AppConfig,
pool: &'static HttpRequestPool, pool: &'static HttpRequestPool,
} }
@ -34,8 +34,7 @@ impl HttpRequest {
head: Message<RequestHead>, head: Message<RequestHead>,
payload: Payload, payload: Payload,
rmap: Rc<ResourceMap>, rmap: Rc<ResourceMap>,
config: AppConfig, app_state: AppState,
app_state: Rc<Extensions>,
pool: &'static HttpRequestPool, pool: &'static HttpRequestPool,
) -> HttpRequest { ) -> HttpRequest {
HttpRequest(Rc::new(HttpRequestInner { HttpRequest(Rc::new(HttpRequestInner {
@ -44,7 +43,6 @@ impl HttpRequest {
payload, payload,
app_state, app_state,
rmap, rmap,
config,
pool, pool,
})) }))
} }
@ -213,7 +211,7 @@ impl HttpRequest {
/// App config /// App config
#[inline] #[inline]
pub fn app_config(&self) -> &AppConfig { pub fn app_config(&self) -> &AppConfig {
&self.0.config self.0.app_state.config()
} }
/// Get an application state object stored with `App::state()` or `App::app_state()` /// Get an application state object stored with `App::state()` or `App::app_state()`
@ -467,22 +465,22 @@ mod tests {
#[crate::rt_test] #[crate::rt_test]
async fn test_state() { async fn test_state() {
let srv = init_service(App::new().app_state(10usize).service( let srv = init_service(App::new().state(10usize).service(web::resource("/").to(
web::resource("/").to(|req: HttpRequest| async move { |req: HttpRequest| async move {
if req.app_state::<usize>().is_some() { if req.app_state::<usize>().is_some() {
HttpResponse::Ok() HttpResponse::Ok()
} else { } else {
HttpResponse::BadRequest() HttpResponse::BadRequest()
} }
}), },
)) )))
.await; .await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
let resp = call_service(&srv, req).await; let resp = call_service(&srv, req).await;
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
let srv = init_service(App::new().app_state(10u32).service(web::resource("/").to( let srv = init_service(App::new().state(10u32).service(web::resource("/").to(
|req: HttpRequest| async move { |req: HttpRequest| async move {
if req.app_state::<usize>().is_some() { if req.app_state::<usize>().is_some() {
HttpResponse::Ok() HttpResponse::Ok()

View file

@ -13,6 +13,7 @@ use super::httprequest::HttpRequest;
use super::info::ConnectionInfo; use super::info::ConnectionInfo;
use super::response::WebResponse; use super::response::WebResponse;
use super::rmap::ResourceMap; use super::rmap::ResourceMap;
use super::service::AppState;
/// An service http request /// An service http request
/// ///
@ -224,8 +225,8 @@ impl<Err> WebRequest<Err> {
#[doc(hidden)] #[doc(hidden)]
/// Set new app state container /// Set new app state container
pub fn set_state_container(&mut self, extensions: Rc<Extensions>) { pub(super) fn set_state_container(&mut self, state: AppState) {
Rc::get_mut(&mut (self.req).0).unwrap().app_state = extensions; Rc::get_mut(&mut (self.req).0).unwrap().app_state = state;
} }
/// Request extensions /// Request extensions

View file

@ -17,7 +17,7 @@ use super::request::WebRequest;
use super::responder::Responder; use super::responder::Responder;
use super::response::WebResponse; use super::response::WebResponse;
use super::route::{IntoRoutes, Route, RouteService}; use super::route::{IntoRoutes, Route, RouteService};
use super::{app::Filter, app::Stack, guard::Guard, types::State}; use super::{app::Filter, app::Stack, guard::Guard, service::AppState};
type HttpService<Err: ErrorRenderer> = type HttpService<Err: ErrorRenderer> =
BoxService<WebRequest<Err>, WebResponse, Err::Container>; BoxService<WebRequest<Err>, WebResponse, Err::Container>;
@ -63,10 +63,10 @@ impl<Err: ErrorRenderer> Resource<Err> {
routes: Vec::new(), routes: Vec::new(),
rdef: path.patterns(), rdef: path.patterns(),
name: None, name: None,
state: None,
middleware: Identity, middleware: Identity,
filter: pipeline_factory(Filter::new()), filter: pipeline_factory(Filter::new()),
guards: Vec::new(), guards: Vec::new(),
state: None,
default: Rc::new(RefCell::new(None)), default: Rc::new(RefCell::new(None)),
} }
} }
@ -123,6 +123,38 @@ where
self self
} }
/// Provide resource specific state. This method allows to add extractor
/// configuration or specific state available via `State<T>` extractor.
/// Provided state is available for all routes registered for the current resource.
/// Resource state overrides state registered by `App::state()` method.
///
/// ```rust
/// use ntex::web::{self, App, FromRequest};
///
/// /// extract text data from request
/// async fn index(body: String) -> String {
/// format!("Body {}!", body)
/// }
///
/// fn main() {
/// let app = App::new().service(
/// web::resource("/index.html")
/// // limit size of the payload
/// .state(web::types::PayloadConfig::new(4096))
/// .route(
/// // register handler
/// web::get().to(index)
/// ));
/// }
/// ```
pub fn state<D: 'static>(mut self, st: D) -> Self {
if self.state.is_none() {
self.state = Some(Extensions::new());
}
self.state.as_mut().unwrap().insert(st);
self
}
/// Register a new route. /// Register a new route.
/// ///
/// ```rust /// ```rust
@ -169,46 +201,6 @@ where
self self
} }
/// Provide resource specific state. This method allows to add extractor
/// configuration or specific state available via `State<T>` extractor.
/// Provided state is available for all routes registered for the current resource.
/// Resource state overrides state registered by `App::state()` method.
///
/// ```rust
/// use ntex::web::{self, App, FromRequest};
///
/// /// extract text data from request
/// async fn index(body: String) -> String {
/// format!("Body {}!", body)
/// }
///
/// fn main() {
/// let app = App::new().service(
/// web::resource("/index.html")
/// // limit size of the payload
/// .app_state(web::types::PayloadConfig::new(4096))
/// .route(
/// web::get()
/// // register handler
/// .to(index)
/// ));
/// }
/// ```
pub fn state<D: 'static>(self, st: D) -> Self {
self.app_state(State::new(st))
}
/// Set or override application state.
///
/// This method overrides state stored with [`App::app_state()`](#method.app_state)
pub fn app_state<D: 'static>(mut self, st: D) -> Self {
if self.state.is_none() {
self.state = Some(Extensions::new());
}
self.state.as_mut().unwrap().insert(st);
self
}
/// Register a new route and add handler. This route matches all requests. /// Register a new route and add handler. This route matches all requests.
/// ///
/// ```rust /// ```rust
@ -269,10 +261,10 @@ where
middleware: self.middleware, middleware: self.middleware,
rdef: self.rdef, rdef: self.rdef,
name: self.name, name: self.name,
state: self.state,
guards: self.guards, guards: self.guards,
routes: self.routes, routes: self.routes,
default: self.default, default: self.default,
state: self.state,
} }
} }
@ -289,10 +281,10 @@ where
filter: self.filter, filter: self.filter,
rdef: self.rdef, rdef: self.rdef,
name: self.name, name: self.name,
state: self.state,
guards: self.guards, guards: self.guards,
routes: self.routes, routes: self.routes,
default: self.default, default: self.default,
state: self.state,
} }
} }
@ -342,14 +334,18 @@ where
if let Some(ref name) = self.name { if let Some(ref name) = self.name {
*rdef.name_mut() = name.clone(); *rdef.name_mut() = name.clone();
} }
// custom app data storage
if let Some(ref mut ext) = self.state { let state = self.state.take().map(|state| {
config.set_service_state(ext); AppState::new(
} state,
Some(config.state().clone()),
config.state().config().clone(),
)
});
let router_factory = ResourceRouterFactory { let router_factory = ResourceRouterFactory {
state,
routes: self.routes, routes: self.routes,
state: self.state.map(Rc::new),
default: self.default, default: self.default,
}; };
@ -386,8 +382,8 @@ where
self, self,
) -> ResourceServiceFactory<Err, M, PipelineFactory<T, WebRequest<Err>>> { ) -> ResourceServiceFactory<Err, M, PipelineFactory<T, WebRequest<Err>>> {
let router_factory = ResourceRouterFactory { let router_factory = ResourceRouterFactory {
state: None,
routes: self.routes, routes: self.routes,
state: self.state.map(Rc::new),
default: self.default, default: self.default,
}; };
@ -506,8 +502,8 @@ where
struct ResourceRouterFactory<Err: ErrorRenderer> { struct ResourceRouterFactory<Err: ErrorRenderer> {
routes: Vec<Route<Err>>, routes: Vec<Route<Err>>,
state: Option<Rc<Extensions>>,
default: Rc<RefCell<Option<Rc<HttpNewService<Err>>>>>, default: Rc<RefCell<Option<Rc<HttpNewService<Err>>>>>,
state: Option<AppState>,
} }
impl<Err: ErrorRenderer> ServiceFactory<WebRequest<Err>> for ResourceRouterFactory<Err> { impl<Err: ErrorRenderer> ServiceFactory<WebRequest<Err>> for ResourceRouterFactory<Err> {
@ -530,8 +526,8 @@ impl<Err: ErrorRenderer> ServiceFactory<WebRequest<Err>> for ResourceRouterFacto
}; };
Ok(ResourceRouter { Ok(ResourceRouter {
routes,
state, state,
routes,
default, default,
}) })
}) })
@ -539,8 +535,8 @@ impl<Err: ErrorRenderer> ServiceFactory<WebRequest<Err>> for ResourceRouterFacto
} }
struct ResourceRouter<Err: ErrorRenderer> { struct ResourceRouter<Err: ErrorRenderer> {
state: Option<AppState>,
routes: Vec<RouteService<Err>>, routes: Vec<RouteService<Err>>,
state: Option<Rc<Extensions>>,
default: Option<HttpService<Err>>, default: Option<HttpService<Err>>,
} }
@ -746,26 +742,22 @@ mod tests {
#[crate::rt_test] #[crate::rt_test]
async fn test_state() { async fn test_state() {
let srv = init_service( let srv = init_service(
App::new() App::new().state(1i32).state(1usize).state('-').service(
.state(1i32) web::resource("/test")
.state(1usize) .state(10usize)
.app_state(web::types::State::new('-')) .state('*')
.service( .guard(guard::Get())
web::resource("/test") .to(
.state(10usize) |data1: web::types::State<usize>,
.app_state(web::types::State::new('*')) data2: web::types::State<char>,
.guard(guard::Get()) data3: web::types::State<i32>| {
.to( assert_eq!(*data1, 10);
|data1: web::types::State<usize>, assert_eq!(*data2, '*');
data2: web::types::State<char>, assert_eq!(*data3, 1);
data3: web::types::State<i32>| { async { HttpResponse::Ok() }
assert_eq!(**data1, 10); },
assert_eq!(**data2, '*'); ),
assert_eq!(**data3, 1); ),
async { HttpResponse::Ok() }
},
),
),
) )
.await; .await;

View file

@ -19,8 +19,7 @@ use super::resource::Resource;
use super::response::WebResponse; use super::response::WebResponse;
use super::rmap::ResourceMap; use super::rmap::ResourceMap;
use super::route::Route; use super::route::Route;
use super::service::{AppServiceFactory, ServiceFactoryWrapper}; use super::service::{AppServiceFactory, AppState, ServiceFactoryWrapper};
use super::types::State;
type Guards = Vec<Box<dyn Guard>>; type Guards = Vec<Box<dyn Guard>>;
type HttpService<Err: ErrorRenderer> = type HttpService<Err: ErrorRenderer> =
@ -149,14 +148,7 @@ where
/// ); /// );
/// } /// }
/// ``` /// ```
pub fn state<D: 'static>(self, st: D) -> Self { pub fn state<D: 'static>(mut self, st: D) -> Self {
self.app_state(State::new(st))
}
/// Set or override application state.
///
/// This method overrides state stored with [`App::app_state()`](#method.app_state)
pub fn app_state<D: 'static>(mut self, st: D) -> Self {
if self.state.is_none() { if self.state.is_none() {
self.state = Some(Extensions::new()); self.state = Some(Extensions::new());
} }
@ -211,11 +203,7 @@ where
if !cfg.state.is_empty() { if !cfg.state.is_empty() {
let mut state = self.state.unwrap_or_else(Extensions::new); let mut state = self.state.unwrap_or_else(Extensions::new);
state.extend(cfg.state);
for value in cfg.state.iter() {
value.create(&mut state);
}
self.state = Some(state); self.state = Some(state);
} }
self self
@ -390,8 +378,16 @@ where
*self.default.borrow_mut() = Some(config.default_service()); *self.default.borrow_mut() = Some(config.default_service());
} }
let state = self.state.take().map(|state| {
AppState::new(
state,
Some(config.state().clone()),
config.state().config().clone(),
)
});
// register nested services // register nested services
let mut cfg = config.clone_config(); let mut cfg = config.clone_config(state.clone());
self.services self.services
.into_iter() .into_iter()
.for_each(|mut srv| srv.register(&mut cfg)); .for_each(|mut srv| srv.register(&mut cfg));
@ -404,19 +400,13 @@ where
rmap.add(&mut rdef, None); rmap.add(&mut rdef, None);
} }
// custom app data storage
if let Some(ref mut ext) = self.state {
config.set_service_state(ext);
}
// complete scope pipeline creation // complete scope pipeline creation
let router_factory = ScopeRouterFactory { let router_factory = ScopeRouterFactory {
state: self.state.take().map(Rc::new), state,
default: self.default.clone(), default: self.default.clone(),
case_insensitive: self.case_insensitive, case_insensitive: self.case_insensitive,
services: Rc::new( services: Rc::new(
cfg.into_services() cfg.into_services()
.1
.into_iter() .into_iter()
.map(|(rdef, srv, guards, nested)| { .map(|(rdef, srv, guards, nested)| {
// case for scope prefix ends with '/' and // case for scope prefix ends with '/' and
@ -560,7 +550,7 @@ where
} }
struct ScopeRouterFactory<Err: ErrorRenderer> { struct ScopeRouterFactory<Err: ErrorRenderer> {
state: Option<Rc<Extensions>>, state: Option<AppState>,
services: Rc<Vec<(ResourceDef, HttpNewService<Err>, RefCell<Option<Guards>>)>>, services: Rc<Vec<(ResourceDef, HttpNewService<Err>, RefCell<Option<Guards>>)>>,
default: Rc<RefCell<Option<Rc<HttpNewService<Err>>>>>, default: Rc<RefCell<Option<Rc<HttpNewService<Err>>>>>,
case_insensitive: bool, case_insensitive: bool,
@ -575,8 +565,8 @@ impl<Err: ErrorRenderer> ServiceFactory<WebRequest<Err>> for ScopeRouterFactory<
fn new_service(&self, _: ()) -> Self::Future { fn new_service(&self, _: ()) -> Self::Future {
let services = self.services.clone(); let services = self.services.clone();
let case_insensitive = self.case_insensitive;
let state = self.state.clone(); let state = self.state.clone();
let case_insensitive = self.case_insensitive;
let default_fut = self let default_fut = self
.default .default
.borrow() .borrow()
@ -610,7 +600,7 @@ impl<Err: ErrorRenderer> ServiceFactory<WebRequest<Err>> for ScopeRouterFactory<
} }
struct ScopeRouter<Err: ErrorRenderer> { struct ScopeRouter<Err: ErrorRenderer> {
state: Option<Rc<Extensions>>, state: Option<AppState>,
router: Router<HttpService<Err>, Vec<Box<dyn Guard>>>, router: Router<HttpService<Err>, Vec<Box<dyn Guard>>>,
default: Option<HttpService<Err>>, default: Option<HttpService<Err>>,
} }
@ -1205,48 +1195,6 @@ mod tests {
); );
} }
#[crate::rt_test]
async fn test_override_data() {
let srv = init_service(App::new().state(1usize).service(
web::scope("app").state(10usize).route(
"/t",
web::get().to(|data: web::types::State<usize>| {
assert_eq!(**data, 10);
async { HttpResponse::Ok() }
}),
),
))
.await;
let req = TestRequest::with_uri("/app/t").to_request();
let resp = call_service(&srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[crate::rt_test]
async fn test_override_app_data() {
let srv = init_service(
App::new()
.app_state(web::types::State::new(1usize))
.service(
web::scope("app")
.app_state(web::types::State::new(10usize))
.route(
"/t",
web::get().to(|data: web::types::State<usize>| {
assert_eq!(**data, 10);
async { HttpResponse::Ok() }
}),
),
),
)
.await;
let req = TestRequest::with_uri("/app/t").to_request();
let resp = call_service(&srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[crate::rt_test] #[crate::rt_test]
async fn test_scope_config() { async fn test_scope_config() {
let srv = init_service(App::new().service(web::scope("/app").configure(|s| { let srv = init_service(App::new().service(web::scope("/app").configure(|s| {
@ -1260,6 +1208,24 @@ mod tests {
assert_eq!(resp.status(), StatusCode::OK); assert_eq!(resp.status(), StatusCode::OK);
} }
#[crate::rt_test]
async fn test_override_state() {
let srv = init_service(App::new().state(1usize).service(
web::scope("app").state(10usize).route(
"/t",
web::get().to(|data: web::types::State<usize>| {
assert_eq!(*data, 10);
async { HttpResponse::Ok() }
}),
),
))
.await;
let req = TestRequest::with_uri("/app/t").to_request();
let resp = call_service(&srv, req).await;
assert_eq!(resp.status(), StatusCode::OK);
}
#[crate::rt_test] #[crate::rt_test]
async fn test_scope_config_2() { async fn test_scope_config_2() {
let srv = init_service(App::new().service(web::scope("/app").configure(|s| { let srv = init_service(App::new().service(web::scope("/app").configure(|s| {

View file

@ -8,10 +8,7 @@ use super::config::AppConfig;
use super::dev::insert_slesh; use super::dev::insert_slesh;
use super::error::ErrorRenderer; use super::error::ErrorRenderer;
use super::guard::Guard; use super::guard::Guard;
use super::request::WebRequest; use super::{request::WebRequest, response::WebResponse, rmap::ResourceMap};
use super::response::WebResponse;
use super::rmap::ResourceMap;
use super::types::state::StateFactory;
pub trait WebServiceFactory<Err: ErrorRenderer> { pub trait WebServiceFactory<Err: ErrorRenderer> {
fn register(self, config: &mut WebServiceConfig<Err>); fn register(self, config: &mut WebServiceConfig<Err>);
@ -49,9 +46,58 @@ type Guards = Vec<Box<dyn Guard>>;
type HttpServiceFactory<Err: ErrorRenderer> = type HttpServiceFactory<Err: ErrorRenderer> =
boxed::BoxServiceFactory<(), WebRequest<Err>, WebResponse, Err::Container, ()>; boxed::BoxServiceFactory<(), WebRequest<Err>, WebResponse, Err::Container, ()>;
#[derive(Debug, Clone)]
pub(crate) struct AppState(Rc<AppStateInner>);
#[derive(Debug)]
struct AppStateInner {
ext: Extensions,
parent: Option<AppState>,
config: AppConfig,
}
impl AppState {
pub(crate) fn new(
ext: Extensions,
parent: Option<AppState>,
config: AppConfig,
) -> Self {
AppState(Rc::new(AppStateInner {
ext,
parent,
config,
}))
}
pub(crate) fn config(&self) -> &AppConfig {
&self.0.config
}
pub(crate) fn get<T: 'static>(&self) -> Option<&T> {
let result = self.0.ext.get::<T>();
if result.is_some() {
result
} else if let Some(parent) = self.0.parent.as_ref() {
parent.get::<T>()
} else {
None
}
}
pub(crate) fn contains<T: 'static>(&self) -> bool {
if self.0.ext.contains::<T>() {
true
} else if let Some(parent) = self.0.parent.as_ref() {
parent.contains::<T>()
} else {
false
}
}
}
/// Application service configuration /// Application service configuration
pub struct WebServiceConfig<Err: ErrorRenderer> { pub struct WebServiceConfig<Err: ErrorRenderer> {
config: AppConfig, state: AppState,
root: bool, root: bool,
default: Rc<HttpServiceFactory<Err>>, default: Rc<HttpServiceFactory<Err>>,
services: Vec<( services: Vec<(
@ -60,20 +106,14 @@ pub struct WebServiceConfig<Err: ErrorRenderer> {
Option<Guards>, Option<Guards>,
Option<Rc<ResourceMap>>, Option<Rc<ResourceMap>>,
)>, )>,
service_state: Rc<Vec<Box<dyn StateFactory>>>,
} }
impl<Err: ErrorRenderer> WebServiceConfig<Err> { impl<Err: ErrorRenderer> WebServiceConfig<Err> {
/// Crate server settings instance /// Crate server settings instance
pub(crate) fn new( pub(crate) fn new(state: AppState, default: Rc<HttpServiceFactory<Err>>) -> Self {
config: AppConfig,
default: Rc<HttpServiceFactory<Err>>,
service_state: Rc<Vec<Box<dyn StateFactory>>>,
) -> Self {
WebServiceConfig { WebServiceConfig {
config, state,
default, default,
service_state,
root: true, root: true,
services: Vec::new(), services: Vec::new(),
} }
@ -84,33 +124,33 @@ impl<Err: ErrorRenderer> WebServiceConfig<Err> {
self.root self.root
} }
pub(crate) fn into_services( pub(super) fn state(&self) -> &AppState {
self, &self.state
) -> (
AppConfig,
Vec<(
ResourceDef,
HttpServiceFactory<Err>,
Option<Guards>,
Option<Rc<ResourceMap>>,
)>,
) {
(self.config, self.services)
} }
pub(crate) fn clone_config(&self) -> Self { pub(crate) fn into_services(
self,
) -> Vec<(
ResourceDef,
HttpServiceFactory<Err>,
Option<Guards>,
Option<Rc<ResourceMap>>,
)> {
self.services
}
pub(crate) fn clone_config(&self, state: Option<AppState>) -> Self {
WebServiceConfig { WebServiceConfig {
config: self.config.clone(), state: state.unwrap_or_else(|| self.state.clone()),
default: self.default.clone(), default: self.default.clone(),
services: Vec::new(), services: Vec::new(),
root: false, root: false,
service_state: self.service_state.clone(),
} }
} }
/// Service configuration /// Service configuration
pub fn config(&self) -> &AppConfig { pub fn config(&self) -> &AppConfig {
&self.config self.state.config()
} }
/// Default resource /// Default resource
@ -118,14 +158,6 @@ impl<Err: ErrorRenderer> WebServiceConfig<Err> {
self.default.clone() self.default.clone()
} }
/// Set global route state
pub fn set_service_state(&self, extensions: &mut Extensions) -> bool {
for f in self.service_state.iter() {
f.create(extensions);
}
!self.service_state.is_empty()
}
/// Register http service /// Register http service
pub fn register_service<F, S>( pub fn register_service<F, S>(
&mut self, &mut self,

View file

@ -23,10 +23,10 @@ use crate::util::{stream_recv, Bytes, BytesMut, Extensions, Ready, Stream};
use crate::ws::{error::WsClientError, WsClient, WsConnection}; use crate::ws::{error::WsClientError, WsClient, WsConnection};
use crate::{io::Sealed, rt::System, server::Server}; use crate::{io::Sealed, rt::System, server::Server};
use crate::web::config::AppConfig;
use crate::web::error::{DefaultError, ErrorRenderer}; use crate::web::error::{DefaultError, ErrorRenderer};
use crate::web::httprequest::{HttpRequest, HttpRequestPool}; use crate::web::httprequest::{HttpRequest, HttpRequestPool};
use crate::web::rmap::ResourceMap; use crate::web::rmap::ResourceMap;
use crate::web::{config::AppConfig, service::AppState};
use crate::web::{FromRequest, HttpResponse, Responder, WebRequest, WebResponse}; use crate::web::{FromRequest, HttpResponse, Responder, WebRequest, WebResponse};
/// Create service that always responds with `HttpResponse::Ok()` /// Create service that always responds with `HttpResponse::Ok()`
@ -460,14 +460,14 @@ impl TestRequest {
pub fn to_srv_request(mut self) -> WebRequest<DefaultError> { pub fn to_srv_request(mut self) -> WebRequest<DefaultError> {
let (head, payload) = self.req.finish().into_parts(); let (head, payload) = self.req.finish().into_parts();
*self.path.get_mut() = head.uri.clone(); *self.path.get_mut() = head.uri.clone();
let app_state = AppState::new(self.app_state, None, self.config);
WebRequest::new(HttpRequest::new( WebRequest::new(HttpRequest::new(
self.path, self.path,
head, head,
payload, payload,
Rc::new(self.rmap), Rc::new(self.rmap),
self.config.clone(), app_state,
Rc::new(self.app_state),
HttpRequestPool::create(), HttpRequestPool::create(),
)) ))
} }
@ -481,14 +481,14 @@ impl TestRequest {
pub fn to_http_request(mut self) -> HttpRequest { pub fn to_http_request(mut self) -> HttpRequest {
let (head, payload) = self.req.finish().into_parts(); let (head, payload) = self.req.finish().into_parts();
*self.path.get_mut() = head.uri.clone(); *self.path.get_mut() = head.uri.clone();
let app_state = AppState::new(self.app_state, None, self.config);
HttpRequest::new( HttpRequest::new(
self.path, self.path,
head, head,
payload, payload,
Rc::new(self.rmap), Rc::new(self.rmap),
self.config.clone(), app_state,
Rc::new(self.app_state),
HttpRequestPool::create(), HttpRequestPool::create(),
) )
} }
@ -497,14 +497,14 @@ impl TestRequest {
pub fn to_http_parts(mut self) -> (HttpRequest, Payload) { pub fn to_http_parts(mut self) -> (HttpRequest, Payload) {
let (head, payload) = self.req.finish().into_parts(); let (head, payload) = self.req.finish().into_parts();
*self.path.get_mut() = head.uri.clone(); *self.path.get_mut() = head.uri.clone();
let app_state = AppState::new(self.app_state, None, self.config);
let req = HttpRequest::new( let req = HttpRequest::new(
self.path, self.path,
head, head,
Payload::None, Payload::None,
Rc::new(self.rmap), Rc::new(self.rmap),
self.config.clone(), app_state,
Rc::new(self.app_state),
HttpRequestPool::create(), HttpRequestPool::create(),
); );
@ -980,7 +980,7 @@ mod tests {
.version(Version::HTTP_2) .version(Version::HTTP_2)
.header(header::DATE, "some date") .header(header::DATE, "some date")
.param("test", "123") .param("test", "123")
.state(web::types::State::new(20u64)) .state(20u64)
.peer_addr("127.0.0.1:8081".parse().unwrap()) .peer_addr("127.0.0.1:8081".parse().unwrap())
.to_http_request(); .to_http_request();
assert!(req.headers().contains_key(header::CONTENT_TYPE)); assert!(req.headers().contains_key(header::CONTENT_TYPE));
@ -988,8 +988,8 @@ mod tests {
// assert_eq!(req.peer_addr(), Some("127.0.0.1:8081".parse().unwrap())); // assert_eq!(req.peer_addr(), Some("127.0.0.1:8081".parse().unwrap()));
assert_eq!(&req.match_info()["test"], "123"); assert_eq!(&req.match_info()["test"], "123");
assert_eq!(req.version(), Version::HTTP_2); assert_eq!(req.version(), Version::HTTP_2);
let data = req.app_state::<web::types::State<u64>>().unwrap(); let data = req.app_state::<u64>().unwrap();
assert_eq!(*data.get_ref(), 20); assert_eq!(*data, 20);
assert_eq!(format!("{:?}", StreamType::Tcp), "StreamType::Tcp"); assert_eq!(format!("{:?}", StreamType::Tcp), "StreamType::Tcp");
} }
@ -1154,7 +1154,7 @@ mod tests {
#[crate::rt_test] #[crate::rt_test]
async fn test_server_state() { async fn test_server_state() {
async fn handler(data: web::types::State<usize>) -> crate::http::ResponseBuilder { async fn handler(data: web::types::State<usize>) -> crate::http::ResponseBuilder {
assert_eq!(**data, 10); assert_eq!(*data, 10);
HttpResponse::Ok() HttpResponse::Ok()
} }

View file

@ -172,7 +172,7 @@ where
/// let app = App::new().service( /// let app = App::new().service(
/// web::resource("/index.html") /// web::resource("/index.html")
/// // change `Form` extractor configuration /// // change `Form` extractor configuration
/// .app_state( /// .state(
/// web::types::FormConfig::default().limit(4097) /// web::types::FormConfig::default().limit(4097)
/// ) /// )
/// .route(web::get().to(index)) /// .route(web::get().to(index))

View file

@ -210,7 +210,7 @@ where
/// fn main() { /// fn main() {
/// let app = App::new().service( /// let app = App::new().service(
/// web::resource("/index.html") /// web::resource("/index.html")
/// .app_state( /// .state(
/// // change json extractor configuration /// // change json extractor configuration
/// web::types::JsonConfig::default() /// web::types::JsonConfig::default()
/// .limit(4096) /// .limit(4096)

View file

@ -186,7 +186,7 @@ impl<Err: ErrorRenderer> FromRequest<Err> for Bytes {
/// fn main() { /// fn main() {
/// let app = App::new().service( /// let app = App::new().service(
/// web::resource("/index.html") /// web::resource("/index.html")
/// .app_state( /// .state(
/// web::types::PayloadConfig::new(4096) // <- limit size of the payload /// web::types::PayloadConfig::new(4096) // <- limit size of the payload
/// ) /// )
/// .route(web::get().to(index)) // <- register handler with extractor params /// .route(web::get().to(index)) // <- register handler with extractor params

View file

@ -1,15 +1,10 @@
use std::{ops::Deref, sync::Arc}; use std::{marker::PhantomData, ops::Deref};
use crate::http::Payload;
use crate::util::{Extensions, Ready};
use crate::web::error::{ErrorRenderer, StateExtractorError}; use crate::web::error::{ErrorRenderer, StateExtractorError};
use crate::web::extract::FromRequest; use crate::web::extract::FromRequest;
use crate::web::httprequest::HttpRequest; use crate::web::httprequest::HttpRequest;
use crate::web::service::AppState;
/// Application data factory use crate::{http::Payload, util::Ready};
pub(crate) trait StateFactory {
fn create(&self, extensions: &mut Extensions) -> bool;
}
/// Application state. /// Application state.
/// ///
@ -26,15 +21,13 @@ pub(crate) trait StateFactory {
/// instance for each thread, thus application data must be constructed /// instance for each thread, thus application data must be constructed
/// multiple times. If you want to share state between different /// multiple times. If you want to share state between different
/// threads, a shareable object should be used, e.g. `Send + Sync`. Application /// threads, a shareable object should be used, e.g. `Send + Sync`. Application
/// state does not need to be `Send` or `Sync`. Internally `State` type /// state does not need to be `Send` or `Sync`.
/// uses `Arc`. if your state implements `Send` + `Sync` traits you can
/// use `web::types::State::new()` and avoid double `Arc`.
/// ///
/// If state is not set for a handler, using `State<T>` extractor would /// If state is not set for a handler, using `State<T>` extractor would
/// cause *Internal Server Error* response. /// cause *Internal Server Error* response.
/// ///
/// ```rust /// ```rust
/// use std::sync::Mutex; /// use std::sync::{Arc, Mutex};
/// use ntex::web::{self, App, HttpResponse}; /// use ntex::web::{self, App, HttpResponse};
/// ///
/// struct MyState { /// struct MyState {
@ -42,58 +35,44 @@ pub(crate) trait StateFactory {
/// } /// }
/// ///
/// /// Use `State<T>` extractor to access data in handler. /// /// Use `State<T>` extractor to access data in handler.
/// async fn index(st: web::types::State<Mutex<MyState>>) -> HttpResponse { /// async fn index(st: web::types::State<Arc<Mutex<MyState>>>) -> HttpResponse {
/// let mut data = st.lock().unwrap(); /// let mut data = st.lock().unwrap();
/// data.counter += 1; /// data.counter += 1;
/// HttpResponse::Ok().into() /// HttpResponse::Ok().into()
/// } /// }
/// ///
/// fn main() { /// fn main() {
/// let st = web::types::State::new(Mutex::new(MyState{ counter: 0 })); /// let st = Arc::new(Mutex::new(MyState{ counter: 0 }));
/// ///
/// let app = App::new() /// let app = App::new()
/// // Store `MyState` in application storage. /// // Store `MyState` in application storage.
/// .app_state(st.clone()) /// .state(st.clone())
/// .service( /// .service(
/// web::resource("/index.html").route( /// web::resource("/index.html").route(
/// web::get().to(index))); /// web::get().to(index)));
/// } /// }
/// ``` /// ```
#[derive(Debug)] #[derive(Debug)]
pub struct State<T>(Arc<T>); pub struct State<T>(AppState, PhantomData<T>);
impl<T> State<T> {
/// Create new `State` instance.
///
/// Internally `State` type uses `Arc`. if your state implements
/// `Send` + `Sync` traits you can use `web::types::State::new()` and
/// avoid double `Arc`.
pub fn new(state: T) -> State<T> {
State(Arc::new(state))
}
impl<T: 'static> State<T> {
/// Get reference to inner app data. /// Get reference to inner app data.
pub fn get_ref(&self) -> &T { pub fn get_ref(&self) -> &T {
self.0.as_ref() self.0.get::<T>().expect("Unexpected state")
}
/// Convert to the internal Arc<T>
pub fn into_inner(self) -> Arc<T> {
self.0
} }
} }
impl<T> Deref for State<T> { impl<T: 'static> Deref for State<T> {
type Target = Arc<T>; type Target = T;
fn deref(&self) -> &Arc<T> { fn deref(&self) -> &T {
&self.0 self.get_ref()
} }
} }
impl<T> Clone for State<T> { impl<T> Clone for State<T> {
fn clone(&self) -> State<T> { fn clone(&self) -> State<T> {
State(self.0.clone()) State(self.0.clone(), PhantomData)
} }
} }
@ -103,8 +82,8 @@ impl<T: 'static, E: ErrorRenderer> FromRequest<E> for State<T> {
#[inline] #[inline]
fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future {
if let Some(st) = req.app_state::<State<T>>() { if req.0.app_state.contains::<T>() {
Ready::Ok(st.clone()) Ready::Ok(Self(req.0.app_state.clone(), PhantomData))
} else { } else {
log::debug!( log::debug!(
"Failed to construct App-level State extractor. \ "Failed to construct App-level State extractor. \
@ -116,20 +95,9 @@ impl<T: 'static, E: ErrorRenderer> FromRequest<E> for State<T> {
} }
} }
impl<T: 'static> StateFactory for State<T> {
fn create(&self, extensions: &mut Extensions) -> bool {
if !extensions.contains::<State<T>>() {
extensions.insert(State(self.0.clone()));
true
} else {
false
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{atomic::AtomicUsize, atomic::Ordering, Arc};
use super::*; use super::*;
use crate::http::StatusCode; use crate::http::StatusCode;
@ -138,13 +106,13 @@ mod tests {
use crate::web::{self, App, HttpResponse}; use crate::web::{self, App, HttpResponse};
#[crate::rt_test] #[crate::rt_test]
async fn test_data_extractor() { async fn test_state_extractor() {
let srv = init_service(App::new().state("TEST".to_string()).service( let srv = init_service(
web::resource("/").to(|data: web::types::State<String>| async move { App::new().state(10usize).service(
assert_eq!(data.to_lowercase(), "test"); web::resource("/")
HttpResponse::Ok() .to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
}), ),
)) )
.await; .await;
let req = TestRequest::default().to_request(); let req = TestRequest::default().to_request();
@ -163,78 +131,9 @@ mod tests {
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
} }
#[crate::rt_test]
async fn test_app_data_extractor() {
let srv = init_service(
App::new().app_state(State::new(10usize)).service(
web::resource("/")
.to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
),
)
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let srv = init_service(
App::new().app_state(State::new(10u32)).service(
web::resource("/")
.to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
),
)
.await;
let req = TestRequest::default().to_request();
let res = srv.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[crate::rt_test]
async fn test_route_data_extractor() {
let srv =
init_service(App::new().service(web::resource("/").state(10usize).route(
web::get().to(|data: web::types::State<usize>| async move {
let _ = data.clone();
HttpResponse::Ok()
}),
)))
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// different type
let srv = init_service(App::new().service(web::resource("/").state(10u32).route(
web::get().to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
)))
.await;
let req = TestRequest::default().to_request();
let res = srv.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[crate::rt_test]
async fn test_override_data() {
let srv = init_service(App::new().state(1usize).service(
web::resource("/").state(10usize).route(web::get().to(
|data: web::types::State<usize>| async move {
assert_eq!(**data, 10);
let _ = data.clone();
HttpResponse::Ok()
},
)),
))
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
#[cfg(feature = "tokio")] #[cfg(feature = "tokio")]
#[crate::rt_test] #[crate::rt_test]
async fn test_data_drop() { async fn test_state_drop() {
struct TestData(Arc<AtomicUsize>); struct TestData(Arc<AtomicUsize>);
impl TestData { impl TestData {
@ -275,4 +174,47 @@ mod tests {
assert_eq!(num.load(Ordering::SeqCst), 0); assert_eq!(num.load(Ordering::SeqCst), 0);
} }
#[crate::rt_test]
async fn test_route_state_extractor() {
let srv =
init_service(App::new().service(web::resource("/").state(10usize).route(
web::get().to(|data: web::types::State<usize>| async move {
let _ = data.clone();
HttpResponse::Ok()
}),
)))
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// different type
let srv = init_service(App::new().service(web::resource("/").state(10u32).route(
web::get().to(|_: web::types::State<usize>| async { HttpResponse::Ok() }),
)))
.await;
let req = TestRequest::default().to_request();
let res = srv.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[crate::rt_test]
async fn test_override_state() {
let srv = init_service(App::new().state(1usize).service(
web::resource("/").state(10usize).route(web::get().to(
|data: web::types::State<usize>| async move {
assert_eq!(*data, 10);
let _ = data.clone();
HttpResponse::Ok()
},
)),
))
.await;
let req = TestRequest::default().to_request();
let resp = srv.call(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
}
} }

View file

@ -680,7 +680,7 @@ async fn test_brotli_encoding_large() {
let srv = test::server_with(test::config().h1(), || { let srv = test::server_with(test::config().h1(), || {
App::new().service( App::new().service(
web::resource("/") web::resource("/")
.app_state(web::types::PayloadConfig::new(320_000)) .state(web::types::PayloadConfig::new(320_000))
.route(web::to(move |body: Bytes| async { .route(web::to(move |body: Bytes| async {
HttpResponse::Ok().streaming(TestBody::new(body, 10240)) HttpResponse::Ok().streaming(TestBody::new(body, 10240))
})), })),