diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index bbff295..fa5bdd6 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -1568,23 +1568,38 @@ def calibration_graph(self, dataset=None, verbose=True): title=f"Last rating: {last_rating}", ) - fig3 = self.calibration_helper( + fig3 = plt.figure() + self.calibration_helper( dataset[["stability", "p", "y"]].copy(), "stability", lambda x: math.pow(1.2, math.floor(math.log(x, 1.2))), True, + fig3.add_subplot(111), ) - fig4 = self.calibration_helper( + + fig4 = plt.figure(figsize=(16, 12)) + for last_rating in (1, 2, 3, 4): + calibration_data = dataset[dataset["last_rating"] == last_rating] + if calibration_data.empty: + continue + self.calibration_helper( + calibration_data[["stability", "p", "y"]].copy(), + "stability", + lambda x: math.pow(1.2, math.floor(math.log(x, 1.2))), + True, + fig4.add_subplot(2, 2, int(last_rating)), + ) + fig5 = plt.figure() + self.calibration_helper( dataset[["difficulty", "p", "y"]].copy(), "difficulty", lambda x: round(x), False, + fig5.add_subplot(111), ) - return metrics, (fig1, fig2, fig3, fig4) + return metrics, (fig1, fig2, fig3, fig4, fig5) - def calibration_helper(self, calibration_data, key, bin_func, semilogx): - fig = plt.figure() - ax1 = fig.add_subplot(111) + def calibration_helper(self, calibration_data, key, bin_func, semilogx, ax1): ax2 = ax1.twinx() lns = [] @@ -1621,7 +1636,7 @@ def to_percent(temp, position): ax2.legend(lns, labs, loc="lower right") ax2.grid(linestyle="--") ax2.yaxis.set_major_formatter(ticker.FuncFormatter(to_percent)) - return fig + return ax1 def formula_analysis(self): analysis_df = self.dataset[self.dataset["i"] > 2].copy()