diff --git a/app/common.py b/app/common.py index 8d048ae..efeaf9b 100644 --- a/app/common.py +++ b/app/common.py @@ -1,16 +1,13 @@ +import os import secrets from pathlib import Path +from dotenv import load_dotenv + from fastapi.templating import Jinja2Templates from pydantic import BaseSettings -class Settings(BaseSettings): - secret_key: str = secrets.token_hex(32) - -settings = Settings() - - file_dir = Path(__file__).parent templates_dir = str( file_dir.parent / 'templates' @@ -18,6 +15,21 @@ templates_dir = str( static_dir = str( file_dir.parent / 'static' ) +debug_env = str( + file_dir.parent / '.env_debug' +) + + +is_debug = bool(os.getenv('DEBUG')) +if is_debug: + load_dotenv(debug_env) + + +class Settings(BaseSettings): + session_key: str = secrets.token_hex(32) + csrf_key: str = secrets.token_hex(32) + +settings = Settings() templates = Jinja2Templates( diff --git a/app/forms/users.py b/app/forms/users.py new file mode 100644 index 0000000..02ae85e --- /dev/null +++ b/app/forms/users.py @@ -0,0 +1,27 @@ +from starlette_wtf import StarletteForm + +from wtforms import IntegerField +from wtforms import StringField, PasswordField + +from wtforms.validators import DataRequired +from wtforms.validators import NumberRange + + +class AddUserForm(StarletteForm): + + pswd = PasswordField('Admin password (1234)') + email = StringField( + label='User\'s e-mail', + validators=[DataRequired()], + ) + name = StringField( + label='User\'s full name', + validators=[DataRequired()], + ) + age = IntegerField( + label='User\'s age', + validators=[ + DataRequired(), + NumberRange(0, 200), + ], + ) diff --git a/app/main.py b/app/main.py index 5624ce6..60f469d 100644 --- a/app/main.py +++ b/app/main.py @@ -1,20 +1,28 @@ from typing import List, Type + from fastapi import FastAPI from fastapi.staticfiles import StaticFiles +from starlette.middleware.sessions import SessionMiddleware +from starlette_wtf import CSRFProtectMiddleware + from . import common +from .sql import db # Add your paths here from .paths.paths import Paths from .paths import pages +from .paths import table from .paths import errors paths: List[Type[Paths]] = [ pages.MainPaths, + table.TablePaths, errors.ErrorsPaths, ] +db.Base.metadata.create_all(bind=db.engine) app = FastAPI() app.mount( '/static', @@ -23,3 +31,12 @@ app.mount( ) for p in paths: p(app).add_paths() + +app.add_middleware( + SessionMiddleware, + secret_key=common.settings.session_key, +) +app.add_middleware( + CSRFProtectMiddleware, + csrf_secret=common.settings.csrf_key, +) diff --git a/app/paths/pages.py b/app/paths/pages.py index a2aa13e..c2e784c 100644 --- a/app/paths/pages.py +++ b/app/paths/pages.py @@ -13,7 +13,7 @@ class MainPaths(paths.Paths): def add_paths(self) -> None: @self.app.get('/') - def index(req: Request) -> Response: + async def index(req: Request) -> Response: return respond.with_tmpl( 'index.html', request=req, diff --git a/app/paths/table.py b/app/paths/table.py new file mode 100644 index 0000000..06d507b --- /dev/null +++ b/app/paths/table.py @@ -0,0 +1,67 @@ +from sqlalchemy.orm import Session + +from fastapi import Depends +from fastapi import Request, Response + +from starlette_wtf import csrf_protect + +from . import paths +from .. import respond +from ..sql import db +from ..sql import crud +from ..sql import schemas +from ..forms.users import AddUserForm + +LIMIT = 10 + + +class TablePaths(paths.Paths): + + def add_paths(self) -> None: + + @self.app.get('/db') + def list_users( + req: Request, + page: int = 0, + db: Session = Depends(db.get_db)) -> Response: + + return respond.with_tmpl( + 'table.html', + request=req, + rows=crud.get_users( + db=db, + skip=(page * LIMIT), + limit=LIMIT, + ), + ) + + @self.app.get('/add') + @self.app.post('/add') + @csrf_protect + async def add_form( + req: Request, + db_s: Session = Depends(db.get_db)) -> Response: + + form = await AddUserForm.from_formdata(request=req) + + if await form.validate_on_submit(): + + if form.pswd.data != '1234': + return respond.with_text('Incorrect password') + + crud.create_user( + db=db_s, + user=schemas.UserCreate( + email=form.email.data, + name=form.name.data, + age=form.age.data or 0, + ), + ) + + return respond.with_redirect('/db') + + return respond.with_tmpl( + 'admin.html', + request=req, + form=form, + ) diff --git a/app/respond.py b/app/respond.py index 279969f..1624f3b 100644 --- a/app/respond.py +++ b/app/respond.py @@ -3,12 +3,71 @@ import mimetypes from typing import Optional, Mapping from fastapi import Response +from fastapi.responses import RedirectResponse +from fastapi.responses import PlainTextResponse from fastapi.responses import FileResponse + from starlette.background import BackgroundTask from .common import templates +def with_redirect( + url: str = '/', + code: int = 302, + headers: Optional[Mapping[str, str]] = None, + background: Optional[BackgroundTask] = None) -> Response: + """Return a redirect to the page specified in `url` + + Args: + url (str, optional): + Target URL (Location header), + root by default + code (int, optional): HTTP response code + headers (Optional[Mapping[str, str]], optional): + Additional headers, passed to Response constructor + background (Optional[BackgroundTask], optional): + Background task, passed to Response constructor + + Returns: + FastAPI's RedirectResponse object + """ + + return RedirectResponse( + url=url, + status_code=code, + headers=headers, + background=background, + ) + + +def with_text( + content: str, + code: int = 200, + headers: Optional[Mapping[str, str]] = None, + background: Optional[BackgroundTask] = None) -> Response: + """Return a plain text to the user + + Args: + content (str): Plain text content + code (int, optional): HTTP response code + headers (Optional[Mapping[str, str]], optional): + Additional headers, passed to Response constructor + background (Optional[BackgroundTask], optional): + Background task, passed to Response constructor + + Returns: + FastAPI's PlainTextResponse object + """ + + return PlainTextResponse( + content=content, + status_code=code, + headers=headers, + background=background, + ) + + def with_tmpl( name: str, code: int = 200, diff --git a/app/sql/crud.py b/app/sql/crud.py index 207be6d..96e7420 100644 --- a/app/sql/crud.py +++ b/app/sql/crud.py @@ -1,4 +1,5 @@ from typing import Optional, List + from sqlalchemy.orm import Session from . import models @@ -18,7 +19,7 @@ def get_user( def get_users( db: Session, skip: int = 0, - limit: int = 100) -> List[Optional[models.User]]: + limit: int = 100) -> List[models.User]: return db \ .query(models.User) \ @@ -29,7 +30,7 @@ def get_users( def create_user( db: Session, - user: schemas.User) -> models.User: + user: schemas.UserCreate) -> models.User: user_model = models.User(**user.dict()) db.add(user_model) diff --git a/app/sql/db.py b/app/sql/db.py index 17d970d..32a32b3 100644 --- a/app/sql/db.py +++ b/app/sql/db.py @@ -1,7 +1,10 @@ -from typing import Generator +from typing import AsyncGenerator + from pydantic import BaseSettings from sqlalchemy import create_engine +from sqlalchemy_utils import database_exists +from sqlalchemy_utils import create_database from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.ext.declarative import declarative_base @@ -15,22 +18,25 @@ class SqlSettings(BaseSettings): sql_settings = SqlSettings() - db_url = ( 'mysql://{db_user}:{db_password}@' - '{db_host}:{db_port}/${db_database}' + '{db_host}:{db_port}/{db_database}' ).format(**sql_settings.dict()) + engine = create_engine(db_url) +if not database_exists(db_url): + create_database(db_url) + SessionLocal = sessionmaker( - autocommit=False, autoflush=False, bind=engine, ) + Base = declarative_base() -def get_db() -> Generator[Session, None, None]: +async def get_db() -> AsyncGenerator[Session, None]: """FastAPI dependency returning database Session object. Code is copied from the official docs diff --git a/app/sql/models.py b/app/sql/models.py index abdbe7f..a654a63 100644 --- a/app/sql/models.py +++ b/app/sql/models.py @@ -7,6 +7,6 @@ class User(Base): __tablename__ = 'users' id = Column(Integer, primary_key=True) - email = Column(String) - name = Column(String) + email = Column(String(32)) + name = Column(String(32)) age = Column(Integer) diff --git a/app/sql/schemas.py b/app/sql/schemas.py index 294775e..1d48d51 100644 --- a/app/sql/schemas.py +++ b/app/sql/schemas.py @@ -1,11 +1,14 @@ from pydantic import BaseModel -class User(BaseModel): - id: int +class UserCreate(BaseModel): email: str name: str age: int + +class User(UserCreate): + id: int + class Config: orm_mode = True diff --git a/requirements.txt b/requirements.txt index 9d933d0..ad31c53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,7 @@ fastapi==0.92.0 uvicorn[standard]==0.20.0 jinja2==3.1.2 starlette-wtf==0.4.3 +sqlalchemy==2.0.4 +sqlalchemy-utils==0.40.0 +mysqlclient==2.1.1 python-dotenv==0.21.1 diff --git a/static/css/style.css b/static/css/style.css index cdaca35..ad391ae 100644 --- a/static/css/style.css +++ b/static/css/style.css @@ -58,3 +58,11 @@ form > div > input:hover, form > div > input:focus { filter: brightness(130%); } + +table { + border-collapse: collapse; +} +td { + border: 1px solid var(--fg); + padding: 5px; +} diff --git a/templates/admin.html b/templates/admin.html index 219eb96..0223fe4 100644 --- a/templates/admin.html +++ b/templates/admin.html @@ -5,7 +5,7 @@ {% block content %}