Skip to content

Commit

Permalink
Merge branch 'main' into merge_ordered_validation
Browse files Browse the repository at this point in the history
  • Loading branch information
ziad-kermadi authored Nov 13, 2023
2 parents 1dfc85d + 76d28c7 commit 228c724
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 21 deletions.
72 changes: 52 additions & 20 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ def __init__(
layout=None,
include_bool: bool = False,
column: IndexLabel | None = None,
*,
logx: bool | None | Literal["sym"] = False,
logy: bool | None | Literal["sym"] = False,
loglog: bool | None | Literal["sym"] = False,
mark_right: bool = True,
stacked: bool = False,
label: Hashable | None = None,
style=None,
**kwds,
) -> None:
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -230,13 +238,13 @@ def __init__(
self.legend_handles: list[Artist] = []
self.legend_labels: list[Hashable] = []

self.logx = kwds.pop("logx", False)
self.logy = kwds.pop("logy", False)
self.loglog = kwds.pop("loglog", False)
self.label = kwds.pop("label", None)
self.style = kwds.pop("style", None)
self.mark_right = kwds.pop("mark_right", True)
self.stacked = kwds.pop("stacked", False)
self.logx = type(self)._validate_log_kwd("logx", logx)
self.logy = type(self)._validate_log_kwd("logy", logy)
self.loglog = type(self)._validate_log_kwd("loglog", loglog)
self.label = label
self.style = style
self.mark_right = mark_right
self.stacked = stacked

# ax may be an Axes object or (if self.subplots) an ndarray of
# Axes objects
Expand Down Expand Up @@ -292,6 +300,22 @@ def _validate_sharex(sharex: bool | None, ax, by) -> bool:
raise TypeError("sharex must be a bool or None")
return bool(sharex)

@classmethod
def _validate_log_kwd(
cls,
kwd: str,
value: bool | None | Literal["sym"],
) -> bool | None | Literal["sym"]:
if (
value is None
or isinstance(value, bool)
or (isinstance(value, str) and value == "sym")
):
return value
raise ValueError(
f"keyword '{kwd}' should be bool, None, or 'sym', not '{value}'"
)

@final
@staticmethod
def _validate_subplots_kwarg(
Expand Down Expand Up @@ -556,14 +580,6 @@ def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]:

axes = flatten_axes(axes)

valid_log = {False, True, "sym", None}
input_log = {self.logx, self.logy, self.loglog}
if input_log - valid_log:
invalid_log = next(iter(input_log - valid_log))
raise ValueError(
f"Boolean, None and 'sym' are valid options, '{invalid_log}' is given."
)

if self.logx is True or self.loglog is True:
[a.set_xscale("log") for a in axes]
elif self.logx == "sym" or self.loglog == "sym":
Expand Down Expand Up @@ -1334,7 +1350,12 @@ def _make_plot(self, fig: Figure):
cbar.ax.set_yticklabels(self.data[c].cat.categories)

if label is not None:
self._append_legend_handles_labels(scatter, label)
self._append_legend_handles_labels(
# error: Argument 2 to "_append_legend_handles_labels" of
# "MPLPlot" has incompatible type "Hashable"; expected "str"
scatter,
label, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues]
)

errors_x = self._get_errorbars(label=x, index=0, yerr=False)
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
Expand Down Expand Up @@ -1999,10 +2020,21 @@ def __init__(self, data, kind=None, **kwargs) -> None:
if (data < 0).any().any():
raise ValueError(f"{self._kind} plot doesn't allow negative values")
MPLPlot.__init__(self, data, kind=kind, **kwargs)
self.grid = False
self.logy = False
self.logx = False
self.loglog = False

@classmethod
def _validate_log_kwd(
cls,
kwd: str,
value: bool | None | Literal["sym"],
) -> bool | None | Literal["sym"]:
super()._validate_log_kwd(kwd=kwd, value=value)
if value is not False:
warnings.warn(
f"PiePlot ignores the '{kwd}' keyword",
UserWarning,
stacklevel=find_stack_level(),
)
return False

def _validate_color_args(self, color, colormap) -> None:
# TODO: warn if color is passed and ignored?
Expand Down
6 changes: 5 additions & 1 deletion pandas/tests/plotting/frame/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,10 +333,14 @@ def test_invalid_logscale(self, input_param):
# GH: 24867
df = DataFrame({"a": np.arange(100)}, index=np.arange(100))

msg = "Boolean, None and 'sym' are valid options, 'sm' is given."
msg = f"keyword '{input_param}' should be bool, None, or 'sym', not 'sm'"
with pytest.raises(ValueError, match=msg):
df.plot(**{input_param: "sm"})

msg = f"PiePlot ignores the '{input_param}' keyword"
with tm.assert_produces_warning(UserWarning, match=msg):
df.plot.pie(subplots=True, **{input_param: True})

def test_xcompat(self):
df = tm.makeTimeDataFrame()
ax = df.plot(x_compat=True)
Expand Down

0 comments on commit 228c724

Please sign in to comment.