Commit
feat: arima model + best params tuning
freddysongg committed Nov 19, 2024
1 parent c10d8ae commit 81a5f62
Showing 6 changed files with 96 additions and 119 deletions.
Empty file added __init__.py
Empty file.
17 changes: 17 additions & 0 deletions data/modify_dataset.py
@@ -0,0 +1,17 @@
import pandas as pd

def prepare_data(file_path, date_col='transaction_date', time_col='transaction_time'):
    data = pd.read_excel(file_path)

    data[date_col] = pd.to_datetime(data[date_col])

    data.set_index(date_col, inplace=True)

    if time_col in data.columns:
        data['transaction_hour'] = data[time_col].apply(lambda x: x.hour)

    return data

if __name__ == "__main__":
    data = prepare_data('../data/cafecast_data.xlsx')
    print(data.info())
13 changes: 0 additions & 13 deletions main.py

This file was deleted.

106 changes: 0 additions & 106 deletions requirements.txt
@@ -1,106 +0,0 @@
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
altair==5.3.0
astunparse==1.6.3
async-timeout==4.0.3
attrs==23.2.0
blinker==1.8.2
cachetools==5.3.3
certifi==2024.6.2
charset-normalizer==3.3.2
click==8.1.7
datasets==2.19.2
dill==0.3.8
et_xmlfile==2.0.0
filelock==3.14.0
Flask==3.0.3
flatbuffers==24.3.25
frozenlist==1.4.1
fsspec==2024.3.1
gast==0.6.0
gitdb==4.0.11
GitPython==3.1.43
google-pasta==0.2.0
grpcio==1.65.4
h5py==3.11.0
huggingface-hub==0.23.2
idna==3.7
itsdangerous==2.2.0
Jinja2==3.1.4
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
keras==3.4.1
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
ml-dtypes==0.3.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
mypy==1.11.1
mypy-extensions==1.0.0
namex==0.0.8
networkx==3.3
numpy==1.26.4
openpyxl==3.1.5
opt-einsum==3.3.0
optree==0.12.1
packaging==24.0
pandas==2.2.2
pillow==10.3.0
protobuf==4.25.3
pyaml==24.7.0
pyarrow==16.1.0
pyarrow-hotfix==0.6
pydeck==0.9.1
Pygments==2.18.0
PyMuPDF==1.24.5
PyMuPDFb==1.24.3
pypng==0.20220715.0
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
qrcode==7.4.2
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
rich==13.7.1
rpds-py==0.18.1
safetensors==0.4.3
six==1.16.0
smmap==5.0.1
streamlit==1.35.0
suno-bark @ git+https://github.com/suno-ai/bark.git@6cd7f0ccd75fbbd9c84c8ce14bf4e700a573eef8
sympy==1.12.1
tenacity==8.3.0
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.2
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-macos==2.16.2
termcolor==2.4.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch==2.3.0
tornado==6.4
tqdm==4.66.4
transformers==4.41.2
types-click==7.1.8
types-Flask==1.1.6
types-Flask-Cors==4.0.0.20240523
types-Jinja2==2.11.9
types-MarkupSafe==1.1.10
types-requests==2.32.0.20240712
types-Werkzeug==1.0.9
typing_extensions==4.12.1
tzdata==2024.1
urllib3==1.26.16
Werkzeug==3.0.3
wrapt==1.16.0
xxhash==3.4.1
yarl==1.9.4
52 changes: 52 additions & 0 deletions src/main.py
@@ -0,0 +1,52 @@
import warnings
warnings.filterwarnings("ignore", "urllib3 v2 only supports OpenSSL")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.arima.model import ARIMA
from sklearn.metrics import mean_absolute_error, mean_squared_error
from data.modify_dataset import prepare_data

data = prepare_data('data/cafecast_data.xlsx')
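# Aggregate individual transactions into a daily total of transaction_qty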
daily_data = data.resample('D')['transaction_qty'].sum()

# 80/20 split
train_size = int(len(daily_data) * 0.8)
train, test = daily_data[:train_size], daily_data[train_size:]

plt.figure(figsize=(14, 7))
plt.plot(train, label='Training Data')
plt.plot(test, label='Testing Data', color='orange')

plt.title('Train/Test Split for Time Series')
plt.xlabel('Date')
plt.ylabel('Transaction Quantity')
plt.legend()
plt.grid(True)
plt.show()

# Hard-coded (p, d, q) order; presumably the best parameters found by the auto_arima search in src/tune_arima.py
best_p, best_d, best_q = 1, 1, 2
model = ARIMA(train, order=(best_p, best_d, best_q))
model_fit = model.fit()

print(model_fit.summary())

forecast = model_fit.forecast(steps=len(test))

# Plot actual vs forecast
plt.figure(figsize=(14, 7))
plt.plot(test.index, test, label='Actual', color='blue')
plt.plot(test.index, forecast, label='Forecast', color='red')
plt.title('Actual vs Forecasted Transaction Quantities')
plt.xlabel('Date')
plt.ylabel('Transaction Quantity')
plt.legend()
plt.grid(True)
plt.show()

# Evaluate model performance
mae = mean_absolute_error(test, forecast)
rmse = np.sqrt(mean_squared_error(test, forecast))
print(f'Mean Absolute Error (MAE): {mae:.2f}')
print(f'Root Mean Squared Error (RMSE): {rmse:.2f}')
27 changes: 27 additions & 0 deletions src/tune_arima.py
@@ -0,0 +1,27 @@
import warnings
warnings.filterwarnings("ignore", "urllib3 v2 only supports OpenSSL")

import pmdarima as pm
import pandas as pd
from data.modify_dataset import prepare_data

data = prepare_data('data/cafecast_data.xlsx')
daily_data = data.resample('D')['transaction_qty'].sum()

# Train-test split
train_size = int(len(daily_data) * 0.8)
train, test = daily_data[:train_size], daily_data[train_size:]

# Use auto_arima to find the optimal (p, d, q)
model = pm.auto_arima(
    train,
    seasonal=False,
    stepwise=True,
    trace=True,
    suppress_warnings=True,
    error_action="ignore",
    max_p=5, max_q=5,
    max_d=2
)

print(f'Optimal ARIMA Order: {model.order}')
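# The selected order can then be plugged into the ARIMA fit in src/main.py as order=(p, d, q)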
