From 81a5f6271962d81b3f32910006b8f2ded032037e Mon Sep 17 00:00:00 2001
From: Freddy Song
Date: Tue, 19 Nov 2024 15:56:54 -0800
Subject: [PATCH] feat: arima model + best params tuning

---
 __init__.py            |   0
 data/modify_dataset.py |  17 +++++++
 main.py                |  13 -----
 requirements.txt       | 106 -----------------------------------------
 src/main.py            |  52 ++++++++++++++++++++
 src/tune_arima.py      |  27 +++++++++++
 6 files changed, 96 insertions(+), 119 deletions(-)
 create mode 100644 __init__.py
 create mode 100644 data/modify_dataset.py
 delete mode 100644 main.py
 create mode 100644 src/main.py
 create mode 100644 src/tune_arima.py

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/modify_dataset.py b/data/modify_dataset.py
new file mode 100644
index 0000000..4ec90d3
--- /dev/null
+++ b/data/modify_dataset.py
@@ -0,0 +1,17 @@
+import pandas as pd
+
+def prepare_data(file_path, date_col='transaction_date', time_col='transaction_time'):
+    data = pd.read_excel(file_path)
+
+    data[date_col] = pd.to_datetime(data[date_col])
+
+    data.set_index(date_col, inplace=True)
+
+    if time_col in data.columns:
+        data['transaction_hour'] = data[time_col].apply(lambda x: x.hour)
+
+    return data
+
+if __name__ == "__main__":
+    data = prepare_data('../data/cafecast_data.xlsx')
+    print(data.info())
diff --git a/main.py b/main.py
deleted file mode 100644
index 5705b52..0000000
--- a/main.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import warnings
-warnings.filterwarnings("ignore", "urllib3 v2 only supports OpenSSL")
-
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-import seaborn as sns
-import tensorflow as tf
-import torch
-
-data = pd.read_excel('data/cafecast_data.xlsx')
-
-print(data.head())
diff --git a/requirements.txt b/requirements.txt
index 366b8ca..e69de29 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,106 +0,0 @@
-absl-py==2.1.0
-aiohttp==3.9.5
-aiosignal==1.3.1
-altair==5.3.0
-astunparse==1.6.3
-async-timeout==4.0.3
-attrs==23.2.0
-blinker==1.8.2
-cachetools==5.3.3
-certifi==2024.6.2
-charset-normalizer==3.3.2
-click==8.1.7
-datasets==2.19.2
-dill==0.3.8
-et_xmlfile==2.0.0
-filelock==3.14.0
-Flask==3.0.3
-flatbuffers==24.3.25
-frozenlist==1.4.1
-fsspec==2024.3.1
-gast==0.6.0
-gitdb==4.0.11
-GitPython==3.1.43
-google-pasta==0.2.0
-grpcio==1.65.4
-h5py==3.11.0
-huggingface-hub==0.23.2
-idna==3.7
-itsdangerous==2.2.0
-Jinja2==3.1.4
-jsonschema==4.22.0
-jsonschema-specifications==2023.12.1
-keras==3.4.1
-libclang==18.1.1
-Markdown==3.6
-markdown-it-py==3.0.0
-MarkupSafe==2.1.5
-mdurl==0.1.2
-ml-dtypes==0.3.2
-mpmath==1.3.0
-multidict==6.0.5
-multiprocess==0.70.16
-mypy==1.11.1
-mypy-extensions==1.0.0
-namex==0.0.8
-networkx==3.3
-numpy==1.26.4
-openpyxl==3.1.5
-opt-einsum==3.3.0
-optree==0.12.1
-packaging==24.0
-pandas==2.2.2
-pillow==10.3.0
-protobuf==4.25.3
-pyaml==24.7.0
-pyarrow==16.1.0
-pyarrow-hotfix==0.6
-pydeck==0.9.1
-Pygments==2.18.0
-PyMuPDF==1.24.5
-PyMuPDFb==1.24.3
-pypng==0.20220715.0
-python-dateutil==2.9.0.post0
-pytz==2024.1
-PyYAML==6.0.1
-qrcode==7.4.2
-referencing==0.35.1
-regex==2024.5.15
-requests==2.32.3
-rich==13.7.1
-rpds-py==0.18.1
-safetensors==0.4.3
-six==1.16.0
-smmap==5.0.1
-streamlit==1.35.0
-suno-bark @ git+https://github.com/suno-ai/bark.git@6cd7f0ccd75fbbd9c84c8ce14bf4e700a573eef8
-sympy==1.12.1
-tenacity==8.3.0
-tensorboard==2.16.2
-tensorboard-data-server==0.7.2
-tensorflow==2.16.2
-tensorflow-io-gcs-filesystem==0.37.1
-tensorflow-macos==2.16.2
-termcolor==2.4.0
-tokenizers==0.19.1
-toml==0.10.2
-tomli==2.0.1
-toolz==0.12.1
-torch==2.3.0
-tornado==6.4
-tqdm==4.66.4
-transformers==4.41.2
-types-click==7.1.8
-types-Flask==1.1.6
-types-Flask-Cors==4.0.0.20240523
-types-Jinja2==2.11.9
-types-MarkupSafe==1.1.10
-types-requests==2.32.0.20240712
-types-Werkzeug==1.0.9
-typing_extensions==4.12.1
-tzdata==2024.1
-urllib3==1.26.16
-Werkzeug==3.0.3
-wrapt==1.16.0
-xxhash==3.4.1
-yarl==1.9.4
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..6c7d40c
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,52 @@
+import warnings
+warnings.filterwarnings("ignore", "urllib3 v2 only supports OpenSSL")
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+from statsmodels.tsa.arima.model import ARIMA
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+from data.modify_dataset import prepare_data
+
+data = prepare_data('data/cafecast_data.xlsx')
+daily_data = data.resample('D')['transaction_qty'].sum()
+
+# 80/20 split
+train_size = int(len(daily_data) * 0.8)
+train, test = daily_data[:train_size], daily_data[train_size:]
+
+plt.figure(figsize=(14, 7))
+plt.plot(train, label='Training Data')
+plt.plot(test, label='Testing Data', color='orange')
+
+plt.title('Train/Test Split for Time Series')
+plt.xlabel('Date')
+plt.ylabel('Transaction Quantity')
+plt.legend()
+plt.grid(True)
+plt.show()
+
+best_p, best_d, best_q = 1, 1, 2
+model = ARIMA(train, order=(best_p, best_d, best_q))
+model_fit = model.fit()
+
+print(model_fit.summary())
+
+forecast = model_fit.forecast(steps=len(test))
+
+# Plot actual vs forecast
+plt.figure(figsize=(14, 7))
+plt.plot(test.index, test, label='Actual', color='blue')
+plt.plot(test.index, forecast, label='Forecast', color='red')
+plt.title('Actual vs Forecasted Transaction Quantities')
+plt.xlabel('Date')
+plt.ylabel('Transaction Quantity')
+plt.legend()
+plt.grid(True)
+plt.show()
+
+# Evaluate model performance
+mae = mean_absolute_error(test, forecast)
+rmse = np.sqrt(mean_squared_error(test, forecast))
+print(f'Mean Absolute Error (MAE): {mae:.2f}')
+print(f'Root Mean Squared Error (RMSE): {rmse:.2f}')
diff --git a/src/tune_arima.py b/src/tune_arima.py
new file mode 100644
index 0000000..eb505f5
--- /dev/null
+++ b/src/tune_arima.py
@@ -0,0 +1,27 @@
+import warnings
+warnings.filterwarnings("ignore", "urllib3 v2 only supports OpenSSL")
+
+import pmdarima as pm
+import pandas as pd
+from data.modify_dataset import prepare_data
+
+data = prepare_data('data/cafecast_data.xlsx')
+daily_data = data.resample('D')['transaction_qty'].sum()
+
+# Train-test split
+train_size = int(len(daily_data) * 0.8)
+train, test = daily_data[:train_size], daily_data[train_size:]
+
+# Use auto_arima to find the optimal (p, d, q)
+model = pm.auto_arima(
+    train,
+    seasonal=False,
+    stepwise=True,
+    trace=True,
+    suppress_warnings=True,
+    error_action="ignore",
+    max_p=5, max_q=5,
+    max_d=2
+)
+
+print(f'Optimal ARIMA Order: {model.order}')
\ No newline at end of file
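
Usage note (not part of the patch): src/main.py hard-codes order=(1, 1, 2), which appears to be the result of running src/tune_arima.py by hand. Below is a minimal sketch of feeding the auto_arima result straight into the statsmodels fit so the tuned order never has to be copied manually. It assumes the same repo layout, data path, and dependencies (pmdarima, statsmodels, scikit-learn) as the code above; the glue script itself is hypothetical.

# Hypothetical glue script (not in this patch): reuse the (p, d, q) selected by
# pmdarima's auto_arima as the order of the statsmodels ARIMA forecast.
import numpy as np
import pmdarima as pm
from statsmodels.tsa.arima.model import ARIMA
from sklearn.metrics import mean_absolute_error, mean_squared_error

from data.modify_dataset import prepare_data  # assumes the repo root is on sys.path

# Daily transaction totals, as in src/main.py and src/tune_arima.py
daily_data = prepare_data('data/cafecast_data.xlsx').resample('D')['transaction_qty'].sum()

# Same 80/20 chronological split as the two scripts above
train_size = int(len(daily_data) * 0.8)
train, test = daily_data[:train_size], daily_data[train_size:]

# Tune on the training split only, then reuse the selected order
auto_model = pm.auto_arima(train, seasonal=False, stepwise=True,
                           suppress_warnings=True, error_action='ignore',
                           max_p=5, max_q=5, max_d=2)

model_fit = ARIMA(train, order=auto_model.order).fit()
forecast = model_fit.forecast(steps=len(test))

# Evaluate on the held-out test window
mae = mean_absolute_error(test, forecast)
rmse = np.sqrt(mean_squared_error(test, forecast))
print(f'Order {auto_model.order}: MAE={mae:.2f}, RMSE={rmse:.2f}')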