Skip to content

Commit

Permalink
Fix/don't print anything if verbose=False (#96)
Browse files (browse the repository at this point in the history)
  • Loading branch information
L-M-Sherlock authored Mar 25, 2024
1 parent 07ae177 commit 2ce4924
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.26.7"
version = "4.26.8"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
64 changes: 34 additions & 30 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def build_dataset(self, train_set: pd.DataFrame, test_set: pd.DataFrame):

self.test_set = BatchDataset(test_set, batch_size=self.batch_size)
self.test_data_loader = BatchLoader(self.test_set)
tqdm.write("dataset built")

def train(self, verbose: bool = True):
self.verbose = verbose
Expand Down Expand Up @@ -820,7 +819,7 @@ def pretrain(self, dataset=None, verbose=True):
group = self.S0_dataset_group[
self.S0_dataset_group["r_history"] == first_rating
]
if group.empty:
if group.empty and verbose:
tqdm.write(
f"Not enough data for first rating {first_rating}. Expected at least 1, got 0."
)
Expand Down Expand Up @@ -976,8 +975,8 @@ def loss(stability):
]

self.init_w[0:4] = list(map(lambda x: max(min(100, x), S_MIN), init_s0))

tqdm.write(f"Pretrain finished!")
if verbose:
tqdm.write(f"Pretrain finished!")
return plots

def train(
Expand All @@ -995,7 +994,8 @@ def train(
axis=1,
)
self.dataset["group"] = self.dataset["r_history"] + self.dataset["t_history"]
tqdm.write("Tensorized!")
if verbose:
tqdm.write("Tensorized!")

w = []
plots = []
Expand All @@ -1004,7 +1004,8 @@ def train(
tscv = TimeSeriesSplit(n_splits=n_splits)
self.dataset.sort_values(by=["review_time"], inplace=True)
for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)):
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
if verbose:
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
train_set = self.dataset.iloc[train_index].copy()
test_set = self.dataset.iloc[test_index].copy()
trainer = Trainer(
Expand All @@ -1021,18 +1022,19 @@ def train(
metrics, figures = self.calibration_graph(
self.dataset.iloc[test_index]
)
print(metrics)
for j, f in enumerate(figures):
f.savefig(f"graph_{j}_test_{i}.png")
plt.close(f)
if verbose:
print(metrics)
plots.append(trainer.plot())
else:
sgkf = StratifiedGroupKFold(n_splits=n_splits)
for train_index, test_index in sgkf.split(
self.dataset, self.dataset["i"], self.dataset["group"]
):
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
if verbose:
tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
train_set = self.dataset.iloc[train_index].copy()
test_set = self.dataset.iloc[test_index].copy()
trainer = Trainer(
Expand Down Expand Up @@ -1063,7 +1065,8 @@ def train(
avg_w = np.round(np.mean(w, axis=0), 4)
self.w = avg_w.tolist()

tqdm.write("\nTraining finished!")
if verbose:
tqdm.write("\nTraining finished!")
return plots

def preview(self, requestRetention: float, verbose=False):
Expand Down Expand Up @@ -1185,7 +1188,6 @@ def predict_memory_states(self):
prediction.sort_values(by=["r_history"], inplace=True)
prediction.rename(columns={"review_time": "count"}, inplace=True)
prediction.to_csv("./prediction.tsv", sep="\t", index=None)
tqdm.write("prediction.tsv saved.")
prediction["difficulty"] = prediction["difficulty"].map(lambda x: int(round(x)))
self.difficulty_distribution = (
prediction.groupby(by=["difficulty"])["count"].sum()
Expand Down Expand Up @@ -1237,22 +1239,22 @@ def find_optimal_retention(
+ recall_cost,
1,
)

tqdm.write(f"average time for failed reviews: {forget_cost}s")
tqdm.write(f"average time for recalled reviews: {recall_cost}s")
tqdm.write(
"average time for `hard`, `good` and `easy` reviews: %.1fs, %.1fs, %.1fs"
% tuple(self.recall_costs)
)
tqdm.write(f"average time for learning a new card: {self.learn_cost}s")
tqdm.write(
"Ratio of `hard`, `good` and `easy` ratings for recalled reviews: %.2f, %.2f, %.2f"
% tuple(self.review_rating_prob)
)
tqdm.write(
"Ratio of `again`, `hard`, `good` and `easy` ratings for new cards: %.2f, %.2f, %.2f, %.2f"
% tuple(self.first_rating_prob)
)
if verbose:
tqdm.write(f"average time for failed reviews: {forget_cost}s")
tqdm.write(f"average time for recalled reviews: {recall_cost}s")
tqdm.write(
"average time for `hard`, `good` and `easy` reviews: %.1fs, %.1fs, %.1fs"
% tuple(self.recall_costs)
)
tqdm.write(f"average time for learning a new card: {self.learn_cost}s")
tqdm.write(
"Ratio of `hard`, `good` and `easy` ratings for recalled reviews: %.2f, %.2f, %.2f"
% tuple(self.review_rating_prob)
)
tqdm.write(
"Ratio of `again`, `hard`, `good` and `easy` ratings for new cards: %.2f, %.2f, %.2f, %.2f"
% tuple(self.first_rating_prob)
)

forget_cost *= loss_aversion

Expand Down Expand Up @@ -1387,12 +1389,13 @@ def evaluate(self, save_to_file=True):
del tmp
return loss_before, loss_after

def calibration_graph(self, dataset=None):
def calibration_graph(self, dataset=None, verbose=True):
if dataset is None:
dataset = self.dataset
fig1 = plt.figure()
rmse = rmse_matrix(dataset)
tqdm.write(f"RMSE(bins): {rmse:.4f}")
if verbose:
tqdm.write(f"RMSE(bins): {rmse:.4f}")
metrics = plot_brier(
dataset["p"], dataset["y"], bins=20, ax=fig1.add_subplot(111)
)
Expand All @@ -1402,9 +1405,10 @@ def calibration_graph(self, dataset=None):
calibration_data = dataset[dataset["r_history"].str.endswith(last_rating)]
if calibration_data.empty:
continue
tqdm.write(f"\nLast rating: {last_rating}")
rmse = rmse_matrix(calibration_data)
tqdm.write(f"RMSE(bins): {rmse:.4f}")
if verbose:
tqdm.write(f"\nLast rating: {last_rating}")
tqdm.write(f"RMSE(bins): {rmse:.4f}")
plot_brier(
calibration_data["p"],
calibration_data["y"],
Expand Down

0 comments on commit 2ce4924

Please sign in to comment.