Skip to content

Commit

Permalink
ENH set default colormap
Browse files Browse the repository at this point in the history
  • Loading branch information
joaopfonseca committed Nov 13, 2024
1 parent b25a419 commit e3ee8a1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
31 changes: 15 additions & 16 deletions sharp/visualization/_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def strata_boxplots(
feature_names=None,
n_strata=5,
gap_size=1,
cmap=None,
cmap="Pastel1",
ax=None,
**kwargs,
):
Expand Down Expand Up @@ -50,38 +50,37 @@ def strata_boxplots(

df["target_binned"] = df["target_binned"].str.replace("<", "$<$")

colors = [
plt.get_cmap(cmap)(i / len(feature_names)) for i in range(len(feature_names))
]
colors = [plt.get_cmap(cmap)(i) for i in range(len(feature_names))]
bin_names = df["target_binned"].unique()
pos_increment = 1 / (len(feature_names) + gap_size)
boxes = []

for i, bin_name in enumerate(bin_names):
box = ax.boxplot(
[df[df["target_binned"] == bin_name][feature] for feature in feature_names],
box = plt.boxplot(
df[df["target_binned"] == bin_name][feature_names],
widths=pos_increment,
positions=[i + pos_increment * n for n in range(len(feature_names))],
patch_artist=True,
medianprops={"color": "black"},
boxprops={"facecolor": "C0", "edgecolor": "black"},
**kwargs,
)
boxes.append(box)

for feature_idx, color in enumerate(colors):
for i, box in enumerate(boxes):
box["boxes"][feature_idx].set_facecolor(color)
for box in boxes:
patches = []
for patch, color in zip(box["boxes"], colors):
patch.set_facecolor(color)
patches.append(patch)

ax.set_xticks(
np.arange(0, len(bin_names)) + pos_increment * (len(feature_names) - 1) / 2
plt.xticks(
np.arange(0, len(bin_names)) + pos_increment * (len(feature_names) - 1) / 2,
bin_names,
)
ax.set_xticklabels(bin_names)

patches = [plt.Line2D([0], [0], color=color, lw=4) for color in colors]
ax.legend(
plt.legend(
patches,
feature_names,
loc="upper center",
loc="lower center",
bbox_to_anchor=(0.5, 1.05),
ncol=len(feature_names),
)
Expand Down
4 changes: 2 additions & 2 deletions sharp/visualization/_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._aggregate import strata_boxplots


class ShaRPViz: # TODO
class ShaRPViz:
def __init__(self, sharp):
self.sharp = sharp

Expand Down Expand Up @@ -54,7 +54,7 @@ def strata_boxplot(
feature_names=None,
n_strata=5,
gap_size=1,
cmap=None,
cmap="Pastel1",
ax=None,
**kwargs
):
Expand Down

0 comments on commit e3ee8a1

Please sign in to comment.