diff --git a/app/main.py b/app/main.py index f409f7d..cb7a407 100644 --- a/app/main.py +++ b/app/main.py @@ -7,7 +7,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette_wtf import CSRFProtectMiddleware from . import common -from .sql import db +from . import sql from .paths import Paths # Add your paths below @@ -23,7 +23,7 @@ paths: List[Type[Paths]] = [ # Initialize SQL database -db.Base.metadata.create_all(bind=db.engine) +sql.Base.metadata.create_all(bind=sql.engine) # Create app app = FastAPI() diff --git a/app/paths/table.py b/app/paths/table.py index e74f9cf..d5dbf9c 100644 --- a/app/paths/table.py +++ b/app/paths/table.py @@ -9,7 +9,7 @@ from starlette_wtf import csrf_protect from . import Paths from .. import respond -from ..sql import db +from .. import sql from ..sql import crud from ..sql import schemas from ..forms import get_form @@ -26,7 +26,7 @@ class TablePaths(Paths): def list_users( req: Request, page: int = 0, - db: Session = Depends(db.get_db)) -> Response: + db: Session = Depends(sql.get_db)) -> Response: return respond.with_tmpl( 'table.html', @@ -43,7 +43,7 @@ class TablePaths(Paths): @csrf_protect async def add_form( req: Request, - db_s: Session = Depends(db.get_db)) -> Response: + db_s: Session = Depends(sql.get_db)) -> Response: form = await get_form(AddUserForm, req) diff --git a/app/sql/__init__.py b/app/sql/__init__.py index e69de29..3a03dea 100644 --- a/app/sql/__init__.py +++ b/app/sql/__init__.py @@ -0,0 +1,64 @@ +from typing import AsyncGenerator + +from pydantic import BaseSettings + +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import database_exists +from sqlalchemy_utils import create_database + + +# Database configuration +class SqlSettings(BaseSettings): + db_host: str = '${REPO_NAME_SNAKE}_db' + db_port: int = 3306 + db_user: str = '${REPO_NAME_SNAKE}' + db_password: str = '' + db_database: str = '${REPO_NAME_SNAKE}' + + +sql_settings = SqlSettings() + +# DB connection URL +# pylint: disable=consider-using-f-string +db_url = ( + 'mysql://{db_user}:{db_password}@' + '{db_host}:{db_port}/{db_database}' +).format(**sql_settings.dict()) +# pylint: enable=consider-using-f-string + + +# SQLAlchemy engine object +engine = create_engine(db_url) + +# Create DB if not exists +if not database_exists(db_url): + create_database(db_url) + +# SQLAlchemy Session object +SessionLocal = sessionmaker( + autoflush=False, + bind=engine, +) + +# SQLAlchemy Base object +Base = declarative_base() + + +# FastAPI dependency +async def get_db() -> AsyncGenerator[Session, None]: + """FastAPI dependency + returning database Session object. + Code is copied from the official docs + + Yields: + SQLAlchemy Session object + """ + + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/app/sql/db.py b/app/sql/db.py deleted file mode 100644 index 3a03dea..0000000 --- a/app/sql/db.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import AsyncGenerator - -from pydantic import BaseSettings - -from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker -from sqlalchemy.ext.declarative import declarative_base - -from sqlalchemy_utils import database_exists -from sqlalchemy_utils import create_database - - -# Database configuration -class SqlSettings(BaseSettings): - db_host: str = '${REPO_NAME_SNAKE}_db' - db_port: int = 3306 - db_user: str = '${REPO_NAME_SNAKE}' - db_password: str = '' - db_database: str = '${REPO_NAME_SNAKE}' - - -sql_settings = SqlSettings() - -# DB connection URL -# pylint: disable=consider-using-f-string -db_url = ( - 'mysql://{db_user}:{db_password}@' - '{db_host}:{db_port}/{db_database}' -).format(**sql_settings.dict()) -# pylint: enable=consider-using-f-string - - -# SQLAlchemy engine object -engine = create_engine(db_url) - -# Create DB if not exists -if not database_exists(db_url): - create_database(db_url) - -# SQLAlchemy Session object -SessionLocal = sessionmaker( - autoflush=False, - bind=engine, -) - -# SQLAlchemy Base object -Base = declarative_base() - - -# FastAPI dependency -async def get_db() -> AsyncGenerator[Session, None]: - """FastAPI dependency - returning database Session object. - Code is copied from the official docs - - Yields: - SQLAlchemy Session object - """ - - db = SessionLocal() - try: - yield db - finally: - db.close() diff --git a/app/sql/models.py b/app/sql/models.py index a654a63..7f8b469 100644 --- a/app/sql/models.py +++ b/app/sql/models.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, String, Integer -from .db import Base +from . import Base class User(Base):