Skip to content

Commit

Permalink
feat: enhance rate calculation to use minimum rates across all determinant keys
Browse files Browse the repository at this point in the history

- Merge base_rate and calculate_period_rates functions into a single calculate_rates function
- Modify rate calculation to use minimum rates from all determinant keys instead of first key only
- Extract report generation into a separate function for better code organization
- Retain all original logging functionality
- Add input validation and proper error handling
  • Loading branch information
Alex870521 committed Feb 5, 2025
1 parent 3141083 commit b335abf
Show file tree
Hide file tree
Showing 10 changed files with 433 additions and 112 deletions.
2 changes: 2 additions & 0 deletions AeroViz/dataProcess/Optical/_absorption.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ def _absCoe(df, instru, specified_band: list):

band_AE33 = np.array([370, 470, 520, 590, 660, 880, 950])
band_BC1054 = np.array([370, 430, 470, 525, 565, 590, 660, 700, 880, 950])
band_MA350 = np.array([375, 470, 528, 625, 880])

MAE_AE33 = np.array([18.47, 14.54, 13.14, 11.58, 10.35, 7.77, 7.19]) * 1e-3
MAE_BC1054 = np.array([18.48, 15.90, 14.55, 13.02, 12.10, 11.59, 10.36, 9.77, 7.77, 7.20]) * 1e-3
MAE_MA350 = np.array([24.069, 19.070, 17.028, 14.091, 10.120]) * 1e-3

band = band_AE33 if instru == 'AE33' else band_BC1054
MAE = MAE_AE33 if instru == 'AE33' else MAE_BC1054
Expand Down
2 changes: 1 addition & 1 deletion AeroViz/plot/templates/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .ammonium_rich import ammonium_rich
from .contour import *
from .corr_matrix import corr_matrix
from .corr_matrix import corr_matrix, cross_corr_matrix
from .diurnal_pattern import *
from .koschmieder import *
from .metal_heatmap import metal_heatmaps, process_data_with_two_df
170 changes: 168 additions & 2 deletions AeroViz/plot/templates/corr_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@

from AeroViz.plot.utils import *

__all__ = ['corr_matrix']
__all__ = ['corr_matrix', 'cross_corr_matrix']


@set_figure
def corr_matrix(data: pd.DataFrame,
cmap: str = "RdBu",
ax: Axes | None = None,
items_order: list = None, # 新增參數用於指定順序
**kwargs
) -> tuple[Figure, Axes]:
fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)

_corr = data.corr()
breakpoint()
corr = pd.melt(_corr.reset_index(), id_vars='index')
corr.columns = ['x', 'y', 'value']

Expand Down Expand Up @@ -94,8 +96,172 @@ def value_to_color(val):
label='p < 0.05'
)

ax.legend(handles=[point2], labels=['p < 0.05'], bbox_to_anchor=(0.05, 1, 0.1, 0.05))
ax.legend(handles=[point2], labels=['p < 0.05'], bbox_to_anchor=(0.02, 1, 0.05, 0.05))

plt.show()

return fig, ax


@set_figure(figsize=(6, 6))
def cross_corr_matrix(data1: pd.DataFrame,
                      data2: pd.DataFrame,
                      cmap: str = "RdBu",
                      ax: Axes | None = None,
                      items_order: list | None = None,
                      **kwargs
                      ) -> tuple[Figure, Axes]:
    """
    Plot the pairwise Pearson correlation between columns of two DataFrames.

    Every selected column of ``data1`` (x axis) is correlated with every
    selected column of ``data2`` (y axis). Each pair is drawn as a square whose
    area scales with the absolute correlation and whose color encodes the
    signed correlation; pairs with ``p < 0.05`` get a star marker.

    Parameters
    ----------
    data1 : pd.DataFrame
        DataFrame whose columns are placed on the x axis.
    data2 : pd.DataFrame
        DataFrame whose columns are placed on the y axis.
    cmap : str, optional
        Name of the color map used for both the squares and the color bar.
    ax : Axes, optional
        Axes to draw on. When omitted a new figure and axes are created.
    items_order : list, optional
        Display order of the items. Entries missing from a DataFrame are
        dropped for that axis. When omitted, each DataFrame's own column
        order is used.
    **kwargs : dict
        ``fig_kws`` : dict whose ``figsize`` entry (default ``(8, 8)``) sizes
        the figure created when ``ax`` is None.
        ``xlabel`` / ``ylabel`` : axis labels (defaults ``'NZ'`` / ``'FS'``).

    Returns
    -------
    tuple[Figure, Axes]
        The figure and the axes holding the matrix.

    Raises
    ------
    ValueError
        If no column remains to plot on either axis.
    """
    # Resolve the columns for each axis before touching matplotlib, so an
    # unusable selection fails without leaving an empty figure behind.
    if items_order is None:
        x_labels = list(data1.columns)
        y_labels = list(data2.columns)
    else:
        # Honor the requested order but keep only columns that really exist.
        x_labels = [item for item in items_order if item in data1.columns]
        y_labels = [item for item in items_order if item in data2.columns]

    if not x_labels or not y_labels:
        raise ValueError(
            "cross_corr_matrix needs at least one column on each axis; "
            f"got {len(x_labels)} from data1 and {len(y_labels)} from data2"
        )

    if ax is None:
        fig_kws = kwargs.get('fig_kws', {})
        default_figsize = fig_kws.get('figsize', (8, 8))
        fig = plt.figure(figsize=default_figsize)
        ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    else:
        fig = ax.get_figure()

    # Pairwise correlation between the two DataFrames.
    correlations = []
    p_values_list = []

    for col1 in x_labels:
        for col2 in y_labels:
            try:
                # Keep only the rows where both series hold a value.
                mask = ~(np.isnan(data1[col1]) | np.isnan(data2[col2]))
                if mask.sum() > 2:
                    corr, p_val = pearsonr(data1[col1][mask], data2[col2][mask])
                else:
                    # Fewer than three paired samples: leave the cell blank.
                    corr, p_val = np.nan, np.nan
            except Exception as e:
                # Best effort: one bad pair must not abort the whole figure.
                print(f"Error calculating correlation for {col1} and {col2}: {str(e)}")
                corr, p_val = np.nan, np.nan

            correlations.append({'x': col1, 'y': col2, 'value': corr})
            # A NaN p-value compares False, so failed pairs are never starred.
            if p_val < 0.05:
                p_values_list.append({'x': col1, 'y': col2, 'value': p_val})

    # Explicit columns keep the frames well-formed even when a list is empty.
    corr = pd.DataFrame(correlations, columns=['x', 'y', 'value'])
    p_values = pd.DataFrame(p_values_list, columns=['x', 'y', 'value'])

    # Map every label to its integer position along the axis.
    x_to_num = {label: i for i, label in enumerate(x_labels)}
    y_to_num = {label: i for i, label in enumerate(y_labels)}

    # Tick labels.
    ax.set_xticks([x_to_num[v] for v in x_labels])
    ax.set_xticklabels(x_labels, rotation=45, ha='right')
    ax.set_yticks([y_to_num[v] for v in y_labels])
    ax.set_yticklabels(y_labels)

    # Grid lines sit on the minor ticks, i.e. between cells.
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)

    ax.set_xlim([-0.5, max(x_to_num.values()) + 0.5])
    ax.set_ylim([-0.5, max(y_to_num.values()) + 0.5])

    # Color mapping: correlations in [-1, 1] index into a discrete palette.
    n_colors = 256
    palette = sns.color_palette(cmap, n_colors=n_colors)
    color_min, color_max = [-1, 1]

    def value_to_color(val):
        # Missing correlations are painted white.
        if pd.isna(val):
            return (1, 1, 1)
        val_position = float((val - color_min)) / (color_max - color_min)
        val_position = np.clip(val_position, 0, 1)
        ind = int(val_position * (n_colors - 1))
        return palette[ind]

    # Correlation squares: area follows |value|, NaN collapses to size 0.
    x_coords = corr['x'].map(x_to_num)
    y_coords = corr['y'].map(y_to_num)
    sizes = corr['value'].abs().fillna(0) * 70
    colors = [value_to_color(val) for val in corr['value']]

    # NOTE(review): the plotted values are Pearson r in [-1, 1], yet the
    # labels read R^2; text left unchanged here - confirm the intended label.
    ax.scatter(
        x=x_coords,
        y=y_coords,
        s=sizes,
        c=colors,
        marker='s',
        label='$R^{2}$'
    )

    # Color bar in its own axes to the right of the matrix. The mappable keeps
    # its default 0..1 range, so the ticks are relabelled to span -1..1.
    cax = fig.add_axes([0.91, 0.1, 0.02, 0.8])
    axes_image = plt.cm.ScalarMappable(cmap=colormaps[cmap])
    cbar = plt.colorbar(mappable=axes_image, cax=cax, label=r'$R^{2}$')
    cbar.set_ticks([0, 0.25, 0.5, 0.75, 1])
    cbar.set_ticklabels(np.linspace(-1, 1, 5))

    # Significance markers, only when at least one pair passed the threshold.
    if not p_values.empty:
        point2 = ax.scatter(
            x=p_values['x'].map(x_to_num),
            y=p_values['y'].map(y_to_num),
            s=10,
            marker='*',
            color='k',
            label='p < 0.05'
        )
        ax.legend(handles=[point2], labels=['p < 0.05'],
                  bbox_to_anchor=(0.005, 1.04), loc='upper left')

    # Axis labels: overridable through kwargs, defaults match the old text.
    ax.set_xlabel(kwargs.get('xlabel', 'NZ'), labelpad=10)
    ax.set_ylabel(kwargs.get('ylabel', 'FS'), labelpad=10)

    plt.show()

    return fig, ax


if __name__ == '__main__':
    import pandas as pd
    from pandas import to_numeric

    # Manual demo: the inputs are CSV exports at developer-local paths, so
    # this only runs on a machine that has those files.
    nz_csv = '/Users/chanchihyu/Desktop/NZ_minion_202402-202411.csv'
    fs_csv = '/Users/chanchihyu/Desktop/FS_minion_202402-202411.csv'

    df_NZ = pd.read_csv(nz_csv, parse_dates=True, index_col=0)
    # NOTE(review): the FS frame is loaded but never used below - confirm
    # whether a cross_corr_matrix call was intended here.
    df_FS = pd.read_csv(fs_csv, parse_dates=True, index_col=0)

    # Columns to plot, in display order.
    items = [
        'Ext', 'Sca', 'Abs', 'PNC', 'PSC', 'PVC',
        'SO2', 'NO', 'NOx', 'NO2', 'CO', 'O3', 'THC', 'NMHC', 'CH4',
        'PM10', 'PM2.5', 'WS', 'AT', 'RH',
        'OC', 'EC', 'Na+', 'NH4+', 'NO3-', 'SO42-',
        'Al', 'Si', 'Ca', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Cu', 'Zn',
    ]

    # Values that cannot be parsed as numbers become NaN.
    df_NZ = df_NZ.apply(to_numeric, errors='coerce')

    corr_matrix(df_NZ[items], items_order=items)
21 changes: 15 additions & 6 deletions AeroViz/plot/templates/metal_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ def normalize_and_split(df, df2):
return df, df2


@set_figure(figsize=(12, 3), fs=6)
@set_figure(figsize=(6, 3), fs=8, fw='normal')
def metal_heatmaps(df,
process=True,
major_freq='24h',
minor_freq='12h',
major_freq='10d',
minor_freq='1d',
cmap='jet',
ax: Axes | None = None,
**kwargs
Expand All @@ -131,7 +131,7 @@ def metal_heatmaps(df,

fig, ax = plt.subplots(**kwargs.get('fig_kws', {})) if ax is None else (ax.get_figure(), ax)

sns.heatmap(df.T, vmin=None, vmax=3, cmap=cmap, xticklabels=False, yticklabels=True,
sns.heatmap(df.T, vmin=None, vmax=3, cmap=cmap, xticklabels=True, yticklabels=True,
cbar_kws={'label': 'Z score', "pad": 0.02})
ax.grid(color='gray', linestyle='-', linewidth=0.3)

Expand All @@ -142,14 +142,23 @@ def metal_heatmaps(df,
# Set the major and minor ticks
ax.set_xticks(ticks=[df.index.get_loc(t) for t in major_tick])
ax.set_xticks(ticks=[df.index.get_loc(t) for t in minor_tick], minor=True)
ax.set_xticklabels(major_tick.strftime('%F'))
ax.set_xticklabels(major_tick.strftime('%F'), rotation=0)
ax.tick_params(axis='y', rotation=0)

ax.set(xlabel='',
ylabel='',
ylabel='Trace metals',
title=kwargs.get('title', None)
)

if kwargs.get('savefig'):
plt.savefig(kwargs.get('savefig'), dpi=600)

plt.show()

return fig, ax


if __name__ == '__main__':
    # Manual smoke test: open an empty 6x6 inch figure with a bold title.
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    plt.title(label='text', font={'weight': 'bold'})
    plt.show()
Loading

0 comments on commit b335abf

Please sign in to comment.