WTForms, SQLAlchemy

This commit is contained in:
DarkCat09 2023-02-20 12:09:43 +04:00
parent 5f328d82d3
commit 03e4c63d38
14 changed files with 227 additions and 22 deletions

View file

@ -1,16 +1,13 @@
import os
import secrets
from pathlib import Path
from dotenv import load_dotenv
from fastapi.templating import Jinja2Templates
from pydantic import BaseSettings
class Settings(BaseSettings):
secret_key: str = secrets.token_hex(32)
settings = Settings()
file_dir = Path(__file__).parent
templates_dir = str(
file_dir.parent / 'templates'
@ -18,6 +15,21 @@ templates_dir = str(
static_dir = str(
file_dir.parent / 'static'
)
debug_env = str(
file_dir.parent / '.env_debug'
)
is_debug = bool(os.getenv('DEBUG'))
if is_debug:
load_dotenv(debug_env)
class Settings(BaseSettings):
session_key: str = secrets.token_hex(32)
csrf_key: str = secrets.token_hex(32)
settings = Settings()
templates = Jinja2Templates(

27
app/forms/users.py Normal file
View file

@ -0,0 +1,27 @@
from starlette_wtf import StarletteForm
from wtforms import IntegerField
from wtforms import StringField, PasswordField
from wtforms.validators import DataRequired
from wtforms.validators import NumberRange
class AddUserForm(StarletteForm):
pswd = PasswordField('Admin password (1234)')
email = StringField(
label='User\'s e-mail',
validators=[DataRequired()],
)
name = StringField(
label='User\'s full name',
validators=[DataRequired()],
)
age = IntegerField(
label='User\'s age',
validators=[
DataRequired(),
NumberRange(0, 200),
],
)

View file

@ -1,20 +1,28 @@
from typing import List, Type
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.middleware.sessions import SessionMiddleware
from starlette_wtf import CSRFProtectMiddleware
from . import common
from .sql import db
# Add your paths here
from .paths.paths import Paths
from .paths import pages
from .paths import table
from .paths import errors
paths: List[Type[Paths]] = [
pages.MainPaths,
table.TablePaths,
errors.ErrorsPaths,
]
db.Base.metadata.create_all(bind=db.engine)
app = FastAPI()
app.mount(
'/static',
@ -23,3 +31,12 @@ app.mount(
)
for p in paths:
p(app).add_paths()
app.add_middleware(
SessionMiddleware,
secret_key=common.settings.session_key,
)
app.add_middleware(
CSRFProtectMiddleware,
csrf_secret=common.settings.csrf_key,
)

View file

@ -13,7 +13,7 @@ class MainPaths(paths.Paths):
def add_paths(self) -> None:
@self.app.get('/')
def index(req: Request) -> Response:
async def index(req: Request) -> Response:
return respond.with_tmpl(
'index.html',
request=req,

67
app/paths/table.py Normal file
View file

@ -0,0 +1,67 @@
from sqlalchemy.orm import Session
from fastapi import Depends
from fastapi import Request, Response
from starlette_wtf import csrf_protect
from . import paths
from .. import respond
from ..sql import db
from ..sql import crud
from ..sql import schemas
from ..forms.users import AddUserForm
LIMIT = 10
class TablePaths(paths.Paths):
def add_paths(self) -> None:
@self.app.get('/db')
def list_users(
req: Request,
page: int = 0,
db: Session = Depends(db.get_db)) -> Response:
return respond.with_tmpl(
'table.html',
request=req,
rows=crud.get_users(
db=db,
skip=(page * LIMIT),
limit=LIMIT,
),
)
@self.app.get('/add')
@self.app.post('/add')
@csrf_protect
async def add_form(
req: Request,
db_s: Session = Depends(db.get_db)) -> Response:
form = await AddUserForm.from_formdata(request=req)
if await form.validate_on_submit():
if form.pswd.data != '1234':
return respond.with_text('Incorrect password')
crud.create_user(
db=db_s,
user=schemas.UserCreate(
email=form.email.data,
name=form.name.data,
age=form.age.data or 0,
),
)
return respond.with_redirect('/db')
return respond.with_tmpl(
'admin.html',
request=req,
form=form,
)

View file

@ -3,12 +3,71 @@ import mimetypes
from typing import Optional, Mapping
from fastapi import Response
from fastapi.responses import RedirectResponse
from fastapi.responses import PlainTextResponse
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask
from .common import templates
def with_redirect(
url: str = '/',
code: int = 302,
headers: Optional[Mapping[str, str]] = None,
background: Optional[BackgroundTask] = None) -> Response:
"""Return a redirect to the page specified in `url`
Args:
url (str, optional):
Target URL (Location header),
root by default
code (int, optional): HTTP response code
headers (Optional[Mapping[str, str]], optional):
Additional headers, passed to Response constructor
background (Optional[BackgroundTask], optional):
Background task, passed to Response constructor
Returns:
FastAPI's RedirectResponse object
"""
return RedirectResponse(
url=url,
status_code=code,
headers=headers,
background=background,
)
def with_text(
content: str,
code: int = 200,
headers: Optional[Mapping[str, str]] = None,
background: Optional[BackgroundTask] = None) -> Response:
"""Return a plain text to the user
Args:
content (str): Plain text content
code (int, optional): HTTP response code
headers (Optional[Mapping[str, str]], optional):
Additional headers, passed to Response constructor
background (Optional[BackgroundTask], optional):
Background task, passed to Response constructor
Returns:
FastAPI's PlainTextResponse object
"""
return PlainTextResponse(
content=content,
status_code=code,
headers=headers,
background=background,
)
def with_tmpl(
name: str,
code: int = 200,

View file

@ -1,4 +1,5 @@
from typing import Optional, List
from sqlalchemy.orm import Session
from . import models
@ -18,7 +19,7 @@ def get_user(
def get_users(
db: Session,
skip: int = 0,
limit: int = 100) -> List[Optional[models.User]]:
limit: int = 100) -> List[models.User]:
return db \
.query(models.User) \
@ -29,7 +30,7 @@ def get_users(
def create_user(
db: Session,
user: schemas.User) -> models.User:
user: schemas.UserCreate) -> models.User:
user_model = models.User(**user.dict())
db.add(user_model)

View file

@ -1,7 +1,10 @@
from typing import Generator
from typing import AsyncGenerator
from pydantic import BaseSettings
from sqlalchemy import create_engine
from sqlalchemy_utils import database_exists
from sqlalchemy_utils import create_database
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
@ -15,22 +18,25 @@ class SqlSettings(BaseSettings):
sql_settings = SqlSettings()
db_url = (
'mysql://{db_user}:{db_password}@'
'{db_host}:{db_port}/${db_database}'
'{db_host}:{db_port}/{db_database}'
).format(**sql_settings.dict())
engine = create_engine(db_url)
if not database_exists(db_url):
create_database(db_url)
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
)
Base = declarative_base()
def get_db() -> Generator[Session, None, None]:
async def get_db() -> AsyncGenerator[Session, None]:
"""FastAPI dependency
returning database Session object.
Code is copied from the official docs

View file

@ -7,6 +7,6 @@ class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
email = Column(String)
name = Column(String)
email = Column(String(32))
name = Column(String(32))
age = Column(Integer)

View file

@ -1,11 +1,14 @@
from pydantic import BaseModel
class User(BaseModel):
id: int
class UserCreate(BaseModel):
email: str
name: str
age: int
class User(UserCreate):
id: int
class Config:
orm_mode = True

View file

@ -2,4 +2,7 @@ fastapi==0.92.0
uvicorn[standard]==0.20.0
jinja2==3.1.2
starlette-wtf==0.4.3
sqlalchemy==2.0.4
sqlalchemy-utils==0.40.0
mysqlclient==2.1.1
python-dotenv==0.21.1

View file

@ -58,3 +58,11 @@ form > div > input:hover,
form > div > input:focus {
filter: brightness(130%);
}
table {
border-collapse: collapse;
}
td {
border: 1px solid var(--fg);
padding: 5px;
}

View file

@ -5,7 +5,7 @@
{% block content %}
<h1>Add a person to DB</h1>
<form action="/add" method="post">
{{ form.hidden_tag() }}
{{ form.csrf_token }}
{% for field in form %}
{% if field.name != 'csrf_token' %}
<div>

View file

@ -4,13 +4,15 @@
{% block content %}
<h1>Sample database</h1>
<p><a href="/add">Add a user</a></p>
<table>
<tbody>
{% for row in rows %}
<tr>
{% for cell in row %}
<td>{{ cell }}</td>
{% endfor %}
<td>{{ row.id }}</td>
<td>{{ row.email }}</td>
<td>{{ row.name }}</td>
<td>{{ row.age }}</td>
</tr>
{% endfor %}
</tbody>