tmpl-fastapi/app/sql/db.py

65 lines
1.4 KiB
Python
Raw Normal View History

2023-02-20 11:09:43 +03:00
from typing import AsyncGenerator
from urllib.parse import quote

from pydantic import BaseSettings
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy_utils import database_exists
from sqlalchemy_utils import create_database
2023-02-19 16:49:44 +03:00
2023-02-27 18:05:04 +03:00
# Database configuration
class SqlSettings(BaseSettings):
    """Connection parameters for the application's database.

    Being a pydantic ``BaseSettings`` subclass, every default below can be
    overridden from the environment using the field name (e.g. ``DB_HOST``,
    ``DB_PASSWORD``; no ``env_prefix`` is configured here).
    """

    # NOTE(review): ``${REPO_NAME_SNAKE}`` looks like a template placeholder
    # that is substituted when a project is generated from this template.

    # Host name of the database server.
    db_host: str = '${REPO_NAME_SNAKE}_db'
    # TCP port of the server; 3306 is the standard MySQL port.
    db_port: int = 3306
    # Account used to connect.
    db_user: str = '${REPO_NAME_SNAKE}'
    # Password for ``db_user``; empty by default.
    db_password: str = ''
    # Name of the database (schema) to use.
    db_database: str = '${REPO_NAME_SNAKE}'


# Settings are read once, when this module is imported.
sql_settings = SqlSettings()
2023-02-27 18:05:04 +03:00
# DB connection URL.
# The user name and password are percent-encoded before being placed in the
# URL. Without this, a credential containing '@', ':', '/', '#' or '%' would
# be split at the wrong place (or silently altered) when the URL is parsed.
# ``quote(..., safe='')`` encodes every reserved character, including '/',
# and, unlike ``quote_plus``, does not rewrite spaces as '+'. Plain
# alphanumeric credentials are left unchanged, so typical URLs are identical
# to what was produced before.
# pylint: disable=consider-using-f-string
db_url = (
    'mysql://{db_user}:{db_password}@'
    '{db_host}:{db_port}/{db_database}'
).format(**dict(
    sql_settings.dict(),
    db_user=quote(sql_settings.db_user, safe=''),
    db_password=quote(sql_settings.db_password, safe=''),
))
# pylint: enable=consider-using-f-string
2023-02-19 16:49:44 +03:00
2023-02-20 11:09:43 +03:00
2023-02-27 18:05:04 +03:00
# Module-level SQLAlchemy engine for the configured database URL.
engine = create_engine(db_url)

# At import time, create the database itself if it is missing.
# NOTE(review): the existence check and the creation are two separate calls,
# so several processes importing this module simultaneously could both
# attempt the creation — confirm this is acceptable for the deployment.
if not database_exists(db_url):
    create_database(db_url)
2023-02-27 18:05:04 +03:00
# Session factory: each call to ``SessionLocal()`` returns a new Session bound
# to ``engine``. ``autoflush=False`` means pending changes are not flushed
# automatically before each query.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
)
2023-02-20 11:09:43 +03:00
2023-02-27 18:05:04 +03:00
# Declarative base class; ORM model classes derive from it, and their tables
# are collected in ``Base.metadata``.
# NOTE(review): ``sqlalchemy.ext.declarative.declarative_base`` is deprecated
# since SQLAlchemy 1.4 in favour of ``sqlalchemy.orm.declarative_base``.
Base = declarative_base()
2023-02-27 18:05:04 +03:00
# FastAPI dependency
async def get_db() -> AsyncGenerator[Session, None]:
    """Provide one SQLAlchemy session per request.

    A fresh session is obtained from ``SessionLocal`` and yielded to the
    caller. When the generator is resumed or closed, the session is closed
    again, regardless of whether the consumer raised. Based on the example
    in the official docs.

    NOTE(review): this is an ``async def`` generator, but ``Session.close()``
    is synchronous and therefore runs directly on the event loop. Consider
    whether a plain ``def`` generator would be preferable.

    Yields:
        SQLAlchemy Session object
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Release the session on every exit path.
        session.close()