diff --git a/chart/function/create_chart.py b/chart/function/create_chart.py index 5faf618..983da04 100644 --- a/chart/function/create_chart.py +++ b/chart/function/create_chart.py @@ -6,98 +6,112 @@ based on historical data retrieved from the database. from datetime import datetime from matplotlib import pyplot as plt +from scipy.interpolate import make_interp_spline +import numpy as np from function.gen_unique_name import generate_unique_name from database.server import create_pool async def create_chart( - from_currency: str, - conv_currency: str, - start_date: str, - end_date: str + from_currency: str, + conv_currency: str, + start_date: str, + end_date: str ) -> (str, None): - """ - Generates a line chart of currency rates for a given date range. + """ + Generates a line chart of currency rates for a given date range. - The chart shows the exchange rate trend between `from_currency` and - `conv_currency` within the specified `start_date` and `end_date` range. - The generated chart is saved as a PNG file, and the function returns the - file name. If data is invalid or insufficient, the function returns `None`. + The chart shows the exchange rate trend between `from_currency` and + `conv_currency` within the specified `start_date` and `end_date` range. + The generated chart is saved as a PNG file, and the function returns the + file name. If data is invalid or insufficient, the function returns `None`. - Args: - from_currency (str): The base currency (e.g., "USD"). - conv_currency (str): The target currency (e.g., "EUR"). - start_date (str): The start date in the format 'YYYY-MM-DD'. - end_date (str): The end date in the format 'YYYY-MM-DD'. + Args: + from_currency (str): The base currency (e.g., "USD"). + conv_currency (str): The target currency (e.g., "EUR"). + start_date (str): The start date in the format 'YYYY-MM-DD'. + end_date (str): The end date in the format 'YYYY-MM-DD'. - Returns: - str | None: The name of the saved chart file, or `None` if the operation fails. - """ - pool = await create_pool() + Returns: + str | None: The name of the saved chart file, or `None` if the operation fails. + """ + pool = await create_pool() - if not validate_date(start_date) or not validate_date(end_date): - return None + if not validate_date(start_date) or not validate_date(end_date): + return None - start_date_obj = datetime.strptime(start_date, '%Y-%m-%d').date() - end_date_obj = datetime.strptime(end_date, '%Y-%m-%d').date() + start_date_obj = datetime.strptime(start_date, '%Y-%m-%d').date() + end_date_obj = datetime.strptime(end_date, '%Y-%m-%d').date() - async with pool.acquire() as conn: - data = await conn.fetch( - 'SELECT date, rate FROM currency ' - 'WHERE (date BETWEEN $1 AND $2) AND from_currency = $3 AND conv_currency = $4 ORDER BY date', - start_date_obj, - end_date_obj, - from_currency.upper(), - conv_currency.upper() - ) + async with pool.acquire() as conn: + data = await conn.fetch( + 'SELECT date, rate FROM currency ' + 'WHERE (date BETWEEN $1 AND $2) ' + + 'AND from_currency = $3 AND conv_currency = $4 ORDER BY date', + start_date_obj, + end_date_obj, + from_currency.upper(), + conv_currency.upper() + ) - if not data or len(data) <= 1: - return None + if not data or len(data) <= 1: + return None - date, rate = [], [] - fig = plt.gcf() + date, rate = [], [] - for row in data: - date.append(str(row['date'])) - rate.append(row['rate']) - width = 18.5 + (len(date) // 5) * 3 - fig.set_size_inches(width, 9.5) + for row in data: + date.append(row[0]) + rate.append(row[1]) + spline = make_interp_spline(range(len(date)), rate, k=2) + x = np.arange(len(date)) - if rate[0] < rate[-1]: - plt.plot(date, rate, color='green', marker='o') - elif rate[0] > rate[-1]: - plt.plot(date, rate, color='red', marker='o') - else: - plt.plot(date, rate, color='grey') + newx_2 = np.linspace(0, len(date) - 1, 200) + newy_2 = spline(newx_2) + fig, ax = plt.subplots(figsize=(15, 6)) - plt.xlabel('Date') - plt.ylabel('Rate') + for label in (ax.get_xticklabels() + ax.get_yticklabels()): + label.set_fontsize(10) + ax.set_xticks(np.linspace(0, len(date) - 1, 10)) + ax.set_xticklabels( + [ + date[int(i)].strftime('%Y-%m-%d') + for i in np.linspace(0, len(date) - 1, 10).astype(int) + ] + ) - name = await generate_unique_name( - f'{from_currency.upper()}_{conv_currency.upper()}', - datetime.now() - ) + name = await generate_unique_name( + f'{from_currency.upper()}_{conv_currency.upper()}', + datetime.now() + ) - fig.savefig(f'../charts/{name}.png') - fig.clear() + + if rate[0] < rate[-1]: + plt.plot(newx_2, newy_2, color='green') + elif rate[0] > rate[-1]: + plt.plot(newx_2, newy_2, color='red') + else: + plt.plot(newx_2, newy_2, color='grey') - return name + plt.savefig(f'../charts/{name}.png') + fig.clear() + + return name def validate_date(date_str: str) -> bool: - """ - Validates whether the provided string is a valid date in the format 'YYYY-MM-DD'. + """ + Validates whether the provided string is a valid date in the format 'YYYY-MM-DD'. - Args: - date_str (str): The date string to validate. + Args: + date_str (str): The date string to validate. - Returns: - bool: `True` if the string is a valid date, `False` otherwise. - """ - try: - datetime.strptime(date_str, '%Y-%m-%d') - return True - except ValueError: - return False + Returns: + bool: `True` if the string is a valid date, `False` otherwise. + """ + try: + datetime.strptime(date_str, '%Y-%m-%d') + return True + except ValueError: + return False diff --git a/chart/requirements.txt b/chart/requirements.txt index f286fcc..278ef87 100644 --- a/chart/requirements.txt +++ b/chart/requirements.txt @@ -4,4 +4,6 @@ uvicorn~=0.29.0 fastapi[standard]~=0.115.2 starlette~=0.40.0 user_agents==2.2.0 -asyncpg~=0.30.0 \ No newline at end of file +asyncpg~=0.30.0 +scipy~=1.15.2 +numpy~=2.2.0 \ No newline at end of file