Skip to content

Commit

Permalink
Merge pull request #3 from MaastrichtU-IDS/secondary
Browse files Browse the repository at this point in the history
even more changes to the EDA scripts to better handle dates
  • Loading branch information
anas-elghafari authored Mar 7, 2025
2 parents 69ebb5c + 3b57d18 commit 26a11fd
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 27 deletions.
49 changes: 36 additions & 13 deletions backend/src/eda_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def c2_save_to_json(cohort_id: str) -> str:

def c3_eda_data_profiling(cohort_id: str) -> str:
raw_script = """
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -341,7 +340,7 @@ def create_save_graph(df, varname, stats_text, vartype, category_mapping=None):
# Right: Plot histogram
sns.histplot(df[varname].dropna(), kde=True, ax=axes[1])
axes[0].set_title(f"Statistics Summary for {varname}", fontsize=12)
axes[0].set_title(f"Statistics Summary for {varname}.upper()", fontsize=12)
axes[1].tick_params(axis='x')
# Save the figure for the current feature
Expand All @@ -352,24 +351,48 @@ def create_save_graph(df, varname, stats_text, vartype, category_mapping=None):
elif vartype == 'datetime':
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
# Left: Display Summary Stats
props = dict(boxstyle="round,pad=0.5", facecolor="whitesmoke", alpha=0.8, edgecolor="lightgray")
text_obj = axes[0].text(0.05, 0.95, '\\n'.join(stats_text), transform=axes[0].transAxes, fontsize=11, va='top', ha='left',
family='monospace', bbox=props, wrap=True, linespacing=1.5)
if hasattr(text_obj, "_get_wrap_line_width"):
text_obj._get_wrap_line_width = lambda: 400
#axes[0].text(0.05, 0.9, , fontsize=10, va='top', ha='left', linespacing=1.2, family='monospace', wrap=True)
axes[0].axis("off")
axes[0].set_title(f"Statistics Summary for {varname}", fontsize=12)
# Right: Plot histogram
sorted_dates = sorted(df[varname].dropna())
axes[1].xaxis.set_major_locator(mdates.MonthLocator()) # Show ticks at month intervals
axes[1].xaxis.set_minor_locator(mdates.WeekdayLocator()) # Minor ticks at weeks
axes[0].set_title(f"Statistics Summary for {varname}.upper()", fontsize=12)
date_vals = pd.to_datetime(df[varname].dropna(), format='%Y-%m-%d')
date_nums = mdates.date2num(date_vals)
min_date = date_vals.min()
max_date = date_vals.max()
date_range = max_date - min_date
# bin frequency based on date range
if date_range.days > 365:
bin_freq = 'M' # Monthly bins for ranges > 1 year
elif date_range.days > 90:
bin_freq = 'W' # Weekly bins for ranges > 3 months
else:
bin_freq = 'D' # Daily bins for shorter ranges
bins = mdates.date2num(pd.date_range(min_date, max_date, freq=bin_freq))
axes[1].hist(date_nums, bins=bins, alpha=0.7)
axes[1].set_title(f"Distribution of {varname}.upper()", fontsize=12)
axes[1].xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.xticks(rotation=90)
if date_range.days > 365*2:
axes[1].xaxis.set_major_locator(mdates.YearLocator())
elif date_range.days > 180:
axes[1].xaxis.set_major_locator(mdates.MonthLocator(interval=3)) # Quarterly
else:
axes[1].xaxis.set_major_locator(mdates.MonthLocator()) # Monthly
axes[1].tick_params(axis='x', rotation=90)
plt.tight_layout()
plt.savefig(f"/output/{varname.lower()}.png")
print(f"figure for {varname} saved!! ")
elif vartype == 'categorical':
Expand All @@ -392,7 +415,7 @@ def create_save_graph(df, varname, stats_text, vartype, category_mapping=None):
if not value_counts.empty:
colors = sns.color_palette("husl", len(value_counts))
ax = value_counts.plot(kind='bar', color=colors, edgecolor='black', ax=axes[1])
axes[0].set_title(f"Statistics Summary for {varname}", fontsize=12)
axes[0].set_title(f"Statistics Summary for {varname}.upper()", fontsize=12)
ax.set_xlabel("Categories")
ax.set_ylabel("Count")
Expand Down
52 changes: 38 additions & 14 deletions backend/src/script_edited.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -201,7 +202,7 @@ def create_save_graph(df, varname, stats_text, vartype, category_mapping=None):

# Right: Plot histogram
sns.histplot(df[varname].dropna(), kde=True, ax=axes[1])
axes[0].set_title(f"Statistics Summary for {varname}", fontsize=12)
axes[0].set_title(f"Statistics Summary for {varname}.upper()", fontsize=12)
axes[1].tick_params(axis='x')

# Save the figure for the current feature
Expand All @@ -212,25 +213,48 @@ def create_save_graph(df, varname, stats_text, vartype, category_mapping=None):
elif vartype == 'datetime':
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Left: Display Summary Stats

props = dict(boxstyle="round,pad=0.5", facecolor="whitesmoke", alpha=0.8, edgecolor="lightgray")
text_obj = axes[0].text(0.05, 0.95, '\n'.join(stats_text), transform=axes[0].transAxes, fontsize=11, va='top', ha='left',
family='monospace', bbox=props, wrap=True, linespacing=1.5)
if hasattr(text_obj, "_get_wrap_line_width"):
text_obj._get_wrap_line_width = lambda: 400
#axes[0].text(0.05, 0.9, , fontsize=10, va='top', ha='left', linespacing=1.2, family='monospace', wrap=True)
axes[0].axis("off")

# Right: Plot histogram
sns.histplot(df[varname].dropna(), kde=True, ax=axes[1])
axes[0].set_title(f"Statistics Summary for {varname}", fontsize=12)
axes[1].tick_params(axis='x')
N = 10 # Show every 10th date
indices = np.arange(0, len(sorted_dates), N)
axes.set_xticks([sorted_dates[i] for i in indices])
axes.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
plt.xticks(rotation=90)
axes[0].set_title(f"Statistics Summary for {varname}.upper()", fontsize=12)
date_vals = pd.to_datetime(df[varname].dropna(), format='%Y-%m-%d')

date_nums = mdates.date2num(date_vals)

min_date = date_vals.min()
max_date = date_vals.max()
date_range = max_date - min_date

# bin frequency based on date range
if date_range.days > 365:
bin_freq = 'M' # Monthly bins for ranges > 1 year
elif date_range.days > 90:
bin_freq = 'W' # Weekly bins for ranges > 3 months
else:
bin_freq = 'D' # Daily bins for shorter ranges

bins = mdates.date2num(pd.date_range(min_date, max_date, freq=bin_freq))

axes[1].hist(date_nums, bins=bins, alpha=0.7)
axes[1].set_title(f"Distribution of {varname}.upper()", fontsize=12)
axes[1].xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

if date_range.days > 365*2:
axes[1].xaxis.set_major_locator(mdates.YearLocator())
elif date_range.days > 180:
axes[1].xaxis.set_major_locator(mdates.MonthLocator(interval=3)) # Quarterly
else:
axes[1].xaxis.set_major_locator(mdates.MonthLocator()) # Monthly

axes[1].tick_params(axis='x', rotation=90)

plt.tight_layout()
plt.savefig(f"/output/{varname.lower()}.png")
print(f"figure for {varname} saved!! ")


elif vartype == 'categorical':
Expand All @@ -253,7 +277,7 @@ def create_save_graph(df, varname, stats_text, vartype, category_mapping=None):
if not value_counts.empty:
colors = sns.color_palette("husl", len(value_counts))
ax = value_counts.plot(kind='bar', color=colors, edgecolor='black', ax=axes[1])
axes[0].set_title(f"Statistics Summary for {varname}", fontsize=12)
axes[0].set_title(f"Statistics Summary for {varname}.upper()", fontsize=12)
ax.set_xlabel("Categories")
ax.set_ylabel("Count")

Expand Down

0 comments on commit 26a11fd

Please sign in to comment.