sql.db -> sql(.init)

This commit is contained in:
DarkCat09 2023-02-27 20:11:55 +04:00
parent a91613cbd3
commit 8311d12d20
5 changed files with 70 additions and 70 deletions

View file

@ -7,7 +7,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette_wtf import CSRFProtectMiddleware from starlette_wtf import CSRFProtectMiddleware
from . import common from . import common
from .sql import db from . import sql
from .paths import Paths from .paths import Paths
# Add your paths below # Add your paths below
@ -23,7 +23,7 @@ paths: List[Type[Paths]] = [
# Initialize SQL database # Initialize SQL database
db.Base.metadata.create_all(bind=db.engine) sql.Base.metadata.create_all(bind=sql.engine)
# Create app # Create app
app = FastAPI() app = FastAPI()

View file

@ -9,7 +9,7 @@ from starlette_wtf import csrf_protect
from . import Paths from . import Paths
from .. import respond from .. import respond
from ..sql import db from .. import sql
from ..sql import crud from ..sql import crud
from ..sql import schemas from ..sql import schemas
from ..forms import get_form from ..forms import get_form
@ -26,7 +26,7 @@ class TablePaths(Paths):
def list_users( def list_users(
req: Request, req: Request,
page: int = 0, page: int = 0,
db: Session = Depends(db.get_db)) -> Response: db: Session = Depends(sql.get_db)) -> Response:
return respond.with_tmpl( return respond.with_tmpl(
'table.html', 'table.html',
@ -43,7 +43,7 @@ class TablePaths(Paths):
@csrf_protect @csrf_protect
async def add_form( async def add_form(
req: Request, req: Request,
db_s: Session = Depends(db.get_db)) -> Response: db_s: Session = Depends(sql.get_db)) -> Response:
form = await get_form(AddUserForm, req) form = await get_form(AddUserForm, req)

View file

@ -0,0 +1,64 @@
from typing import AsyncGenerator
from pydantic import BaseSettings
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import database_exists
from sqlalchemy_utils import create_database
# Database configuration
class SqlSettings(BaseSettings):
db_host: str = '${REPO_NAME_SNAKE}_db'
db_port: int = 3306
db_user: str = '${REPO_NAME_SNAKE}'
db_password: str = ''
db_database: str = '${REPO_NAME_SNAKE}'
sql_settings = SqlSettings()
# DB connection URL
# pylint: disable=consider-using-f-string
db_url = (
'mysql://{db_user}:{db_password}@'
'{db_host}:{db_port}/{db_database}'
).format(**sql_settings.dict())
# pylint: enable=consider-using-f-string
# SQLAlchemy engine object
engine = create_engine(db_url)
# Create DB if not exists
if not database_exists(db_url):
create_database(db_url)
# SQLAlchemy Session object
SessionLocal = sessionmaker(
autoflush=False,
bind=engine,
)
# SQLAlchemy Base object
Base = declarative_base()
# FastAPI dependency
async def get_db() -> AsyncGenerator[Session, None]:
"""FastAPI dependency
returning database Session object.
Code is copied from the official docs
Yields:
SQLAlchemy Session object
"""
db = SessionLocal()
try:
yield db
finally:
db.close()

View file

@ -1,64 +0,0 @@
from typing import AsyncGenerator
from pydantic import BaseSettings
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import database_exists
from sqlalchemy_utils import create_database
# Database configuration
class SqlSettings(BaseSettings):
db_host: str = '${REPO_NAME_SNAKE}_db'
db_port: int = 3306
db_user: str = '${REPO_NAME_SNAKE}'
db_password: str = ''
db_database: str = '${REPO_NAME_SNAKE}'
sql_settings = SqlSettings()
# DB connection URL
# pylint: disable=consider-using-f-string
db_url = (
'mysql://{db_user}:{db_password}@'
'{db_host}:{db_port}/{db_database}'
).format(**sql_settings.dict())
# pylint: enable=consider-using-f-string
# SQLAlchemy engine object
engine = create_engine(db_url)
# Create DB if not exists
if not database_exists(db_url):
create_database(db_url)
# SQLAlchemy Session object
SessionLocal = sessionmaker(
autoflush=False,
bind=engine,
)
# SQLAlchemy Base object
Base = declarative_base()
# FastAPI dependency
async def get_db() -> AsyncGenerator[Session, None]:
"""FastAPI dependency
returning database Session object.
Code is copied from the official docs
Yields:
SQLAlchemy Session object
"""
db = SessionLocal()
try:
yield db
finally:
db.close()

View file

@ -1,6 +1,6 @@
from sqlalchemy import Column, String, Integer from sqlalchemy import Column, String, Integer
from .db import Base from . import Base
class User(Base): class User(Base):