Skip to content

Commit

Permalink
feat: streamlit frontend, rudimentary
Browse files Browse the repository at this point in the history
  • Loading branch information
freddysongg committed Nov 23, 2024
1 parent c21a82e commit ce36415
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 0 deletions.
158 changes: 158 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import streamlit as st
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import json
import tensorflow as tf
import torch
from sklearn.preprocessing import MinMaxScaler
from datetime import datetime

# Helper functions
def load_scaler(path):
with open(path, 'r') as f:
scaler_data = json.load(f)
scaler = MinMaxScaler()
scaler.min_ = np.array(scaler_data['min_'])
scaler.scale_ = np.array(scaler_data['scale_'])
return scaler

def load_lstm_model():
return tf.keras.load_model('models/best_lstm_model.keras')

def load_transformer_model(params_path, model_path):
with open(params_path, 'r') as f:
params = json.load(f)
model = TimeSeriesTransformer(
input_size=params['d_model'],
num_layers=params['num_layers'],
num_heads=params['num_heads'],
d_model=params['d_model'],
dim_feedforward=params['dim_feedforward']
)
model.load_state_dict(torch.load(model_path))
return model, params

# TimeSeriesTransformer class definition (same as in your training code)
class TimeSeriesTransformer(torch.nn.Module):
def __init__(self, input_size, num_layers, num_heads, d_model, dim_feedforward):
super(TimeSeriesTransformer, self).__init__()
self.encoder_layer = torch.nn.TransformerEncoderLayer(
d_model=d_model, nhead=num_heads, dim_feedforward=dim_feedforward, batch_first=True
)
self.transformer_encoder = torch.nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
self.fc = torch.nn.Linear(d_model, 1)

def forward(self, x):
x = self.transformer_encoder(x)
x = self.fc(x[:, -1, :]) # Output of the last time step
return x

# UI Setup
st.set_page_config(
page_title="Café ML Demo",
layout="centered",
initial_sidebar_state="expanded",
)

# Style Settings
st.markdown("""
<style>
body {
background-color: #f4f1ea; /* Cream color */
color: #4e342e; /* Espresso brown */
}
.stButton>button {
background-color: #4e342e;
color: white;
}
.stButton>button:hover {
background-color: #6d4c41;
color: white;
}
</style>
""", unsafe_allow_html=True)

st.title("☕ Café ML Demo")
st.sidebar.title("⚙️ Settings")

# Sidebar options
model_type = st.sidebar.selectbox("Select Model Type", ["LSTM", "Transformer"])
seq_length = st.sidebar.number_input("Sequence Length", min_value=5, max_value=50, value=10, step=1)
uploaded_file = st.sidebar.file_uploader("Upload Test Data (CSV)", type=["csv"])
dark_mode = st.sidebar.checkbox("Enable Dark Mode")

# Dark Mode Styling
if dark_mode:
st.markdown("""
<style>
body {
background-color: #2c2c2c;
color: #f4f1ea;
}
.stButton>button {
background-color: #6d4c41;
color: white;
}
</style>
""", unsafe_allow_html=True)

# Load Data
if uploaded_file:
data = pd.read_csv(uploaded_file)
st.sidebar.write(f"Data preview:")
st.sidebar.write(data.head())
else:
st.sidebar.warning("Upload a CSV file to proceed.")

# Load Models
scaler_path = 'models/scaler.json'
scaler = load_scaler(scaler_path)
lstm_model = None
transformer_model = None
if model_type == "LSTM":
lstm_model = load_lstm_model()
else:
transformer_model, transformer_params = load_transformer_model(
'params/best_ts_transformer_params.json',
'models/best_ts_transformer_model.pt'
)

# Inference
if st.button("Run Inference"):
if not uploaded_file:
st.error("Please upload a test data file first.")
else:
# Scale and process data
scaled_data = scaler.transform(data.values)
sequences = [
scaled_data[i : i + seq_length]
for i in range(len(scaled_data) - seq_length)
]
sequences = np.array(sequences)

# Predict
if model_type == "LSTM":
predictions = lstm_model.predict(sequences)
else:
sequences_torch = torch.FloatTensor(sequences).unsqueeze(-1) # Add feature dim
predictions = transformer_model(sequences_torch).detach().numpy()

# Rescale predictions
predictions_rescaled = scaler.inverse_transform(predictions)

# Visualization
st.success("Inference complete! Here are the results:")
fig = go.Figure()
fig.add_trace(go.Scatter(y=data.values.flatten(), name="Actual", mode="lines"))
fig.add_trace(go.Scatter(y=predictions_rescaled.flatten(), name="Predicted", mode="lines"))
fig.update_layout(
title="Actual vs Predicted",
xaxis_title="Time Steps",
yaxis_title="Values",
template="plotly_dark" if dark_mode else "plotly_white",
)
st.plotly_chart(fig)

# Footer
st.markdown("#### Made with ❤️ for CaféCast")
20 changes: 20 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
absl-py==2.1.0
altair==5.4.1
annotated-types==0.7.0
anyio==4.6.2.post1
astunparse==1.6.3
attrs==24.2.0
bayesian-optimization==2.0.0
blinker==1.9.0
cachetools==5.5.0
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
Expand All @@ -13,6 +17,8 @@ filelock==3.16.1
flatbuffers==24.3.25
fsspec==2024.10.0
gast==0.6.0
gitdb==4.0.11
GitPython==3.1.43
google-pasta==0.2.0
grpcio==1.68.0
h11==0.14.0
Expand All @@ -21,6 +27,8 @@ idna==3.10
importlib_metadata==8.5.0
Jinja2==3.1.4
joblib==1.4.2
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
keras==3.6.0
libclang==18.1.1
Markdown==3.7
Expand All @@ -30,35 +38,47 @@ mdurl==0.1.2
ml-dtypes==0.4.1
mpmath==1.3.0
namex==0.0.8
narwhals==1.14.1
networkx==3.2.1
numpy==2.0.2
opt_einsum==3.4.0
optree==0.13.1
packaging==24.2
pandas==2.2.3
patsy==1.0.1
pillow==11.0.0
plotly==5.24.1
protobuf==5.28.3
pyarrow==18.0.0
pydantic==2.10.1
pydantic_core==2.27.1
pydeck==0.9.1
Pygments==2.18.0
python-dateutil==2.9.0.post0
pytz==2024.2
referencing==0.35.1
requests==2.32.3
rich==13.9.4
rpds-py==0.21.0
scikit-learn==1.5.2
scipy==1.13.1
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
starlette==0.41.3
statsmodels==0.14.4
streamlit==1.40.1
sympy==1.13.1
tenacity==9.0.0
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorflow==2.18.0
tensorflow-io-gcs-filesystem==0.37.1
termcolor==2.5.0
threadpoolctl==3.5.0
toml==0.10.2
torch==2.5.1
tornado==6.4.2
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
Expand Down
File renamed without changes.

0 comments on commit ce36415

Please sign in to comment.