Skip to content

Commit

Permalink
option to include short-term reviews in other.py
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Jul 25, 2024
1 parent df29134 commit cd9fc1a
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions other.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
verbose: bool = False

model_name = os.environ.get("MODEL", "FSRSv3")
short_term = os.environ.get("SHORT")
file_name = model_name + ("-short" if short_term else "")


class FSRS3ParameterClipper:
Expand Down Expand Up @@ -1649,7 +1651,9 @@ def baseline(file):


def create_features(df, model_name="FSRSv3"):
df = df[(df["delta_t"] != 0) & (df["rating"].isin([1, 2, 3, 4]))].copy()
df = df[df["rating"].isin([1, 2, 3, 4])].copy()
if short_term is None:
df = df[df["delta_t"] != 0].copy()
df["delta_t"] = df["delta_t"].map(lambda x: max(0, x))
df["i"] = df.groupby("card_id").cumcount() + 1
t_history = df.groupby("card_id", group_keys=False)["delta_t"].apply(
Expand Down Expand Up @@ -1771,6 +1775,9 @@ def r_history_to_l_history(r_history):

df["first_rating"] = df["r_history"].map(lambda x: x[0] if len(x) > 0 else "")
df["y"] = df["rating"].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x])
if short_term:
df = df[(df["delta_t"] != 0) | (df["i"] == 1)].copy()
df["i"] = df.groupby("card_id").cumcount() + 1
filtered_dataset = (
df[df["i"] == 2]
.groupby(by=["first_rating"], as_index=False, group_keys=False)[df.columns]
Expand Down Expand Up @@ -1864,20 +1871,20 @@ def process(args):
del save_tmp["tensor"]
if os.environ.get("FILE"):
save_tmp.to_csv(
f"evaluation/{model_name}/{file.stem}.tsv", sep="\t", index=False
f"evaluation/{file_name}/{file.stem}.tsv", sep="\t", index=False
)

result, raw = evaluate(
y, p, save_tmp, model_name, file, w_list if type(w_list[0]) == list else None
y, p, save_tmp, file_name, file, w_list if type(w_list[0]) == list else None
)
return result, raw


def evaluate(y, p, df, model_name, file, w_list=None):
def evaluate(y, p, df, file_name, file, w_list=None):
if os.environ.get("PLOT"):
fig = plt.figure()
plot_brier(p, y, ax=fig.add_subplot(111))
fig.savefig(f"evaluation/{model_name}/{file.stem}.png")
fig.savefig(f"evaluation/{file_name}/{file.stem}.png")
p_calibrated = lowess(
y, p, it=0, delta=0.01 * (max(p) - min(p)), return_sorted=False
)
Expand Down Expand Up @@ -1919,11 +1926,11 @@ def evaluate(y, p, df, model_name, file, w_list=None):
dataset_path0 = "./dataset/"
dataset_path1 = "../FSRS-Anki-20k/dataset/1/"
dataset_path2 = "../FSRS-Anki-20k/dataset/2/"
Path(f"evaluation/{model_name}").mkdir(parents=True, exist_ok=True)
Path(f"evaluation/{file_name}").mkdir(parents=True, exist_ok=True)
Path("result").mkdir(parents=True, exist_ok=True)
Path("raw").mkdir(parents=True, exist_ok=True)
result_file = Path(f"result/{model_name}.jsonl")
raw_file = Path(f"raw/{model_name}.jsonl")
result_file = Path(f"result/{file_name}.jsonl")
raw_file = Path(f"raw/{file_name}.jsonl")
if result_file.exists():
data = sort_jsonl(result_file)
processed_user = set(map(lambda x: x["user"], data))
Expand Down

0 comments on commit cd9fc1a

Please sign in to comment.