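"""Run an ablation study for a trained off_categories model.

The script first checks that the train/val/test splits are reproducible:
barcodes must match the reference files stored in a `splits/` directory next
to the model directory, and the splits must not overlap. It then evaluates
the saved model on the val and test splits, optionally with some inputs
blanked out (OCR ingredients, ingredients, nutriments, product name, image
embeddings), and writes one classification report per configuration to
`<model_dir>/ablations/`.
"""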
import copy
import json
from pathlib import Path
from typing import List

import numpy as np
import tensorflow_datasets as tfds
import typer
from rich import print
from sklearn.metrics import classification_report

import datasets.off_categories  # noqa: F401 -- registers the off_categories dataset with tfds
from lib.constant import IMAGE_EMBEDDING_DIM, MAX_IMAGE_EMBEDDING, NUTRIMENT_NAMES
from lib.dataset import load_dataset, select_feature
from lib.io import load_model
from lib.metrics import PrecisionWithAverage, RecallWithAverage
from lib.preprocessing import fix_image_embeddings_mask

PREPROC_BATCH_SIZE = 25_000  # some large value, only affects execution time


def extract_barcodes(ds) -> List[str]:
    """Extract all barcodes in sequential order from a dataset."""
    barcodes = []
    for batch in tfds.as_numpy(select_feature(ds, "code").batch(PREPROC_BATCH_SIZE)):
        barcodes += [x.decode("utf-8") for x in batch.tolist()]
    return barcodes


def generate_y_true(shape, ds, labels: List[str]):
    """Build the binary multi-label ground-truth matrix from each product's categories_tags."""
    y_true = np.zeros(shape, dtype=int)
    label_to_idx = {label: i for i, label in enumerate(labels)}
    categories_tags_all = (
        [cat.decode("utf-8") for cat in x["categories_tags"].numpy().tolist()]
        for x in ds
    )
    for i, categories_tags in enumerate(categories_tags_all):
        for category_tag in categories_tags:
            if category_tag in label_to_idx:
                y_true[i, label_to_idx[category_tag]] = 1
    return y_true


# Ablation helpers: each returns a shallow copy of the example with one input
# family blanked out (empty tag arrays, -1 nutriment values, an empty product
# name, or zeroed image embeddings).
def remove_ingredients_ocr_func(x):
    x = copy.copy(x)
    x["ingredients_ocr_tags"] = np.array([], dtype=np.string_)
    return x


def remove_ingredients_func(x):
    x = copy.copy(x)
    x["ingredients_tags"] = np.array([], dtype=np.string_)
    return x


def remove_nutriments_func(x):
    x = copy.copy(x)
    for nutriment_name in NUTRIMENT_NAMES:
        x[nutriment_name] = np.array(-1, dtype=np.float32)
    return x


def remove_product_name_func(x):
    x = copy.copy(x)
    x["product_name"] = np.array("", dtype=np.string_)
    return x


def remove_image_embeddings_func(x):
    x = copy.copy(x)
    x["image_embeddings"] = np.zeros(
        (MAX_IMAGE_EMBEDDING, IMAGE_EMBEDDING_DIM), dtype=np.float32
    )
    x["image_embeddings_mask"] = np.array(
        [1] + [0] * (MAX_IMAGE_EMBEDDING - 1), dtype=np.int64
    )
    return x


def main(
    model_dir: Path = typer.Option(..., help="name of the model"),
    remove_ingredient_ocr_tags: bool = typer.Option(
        False, help="Remove OCR ingredient tags input from dataset (ablation)"
    ),
    remove_ingredients_tags: bool = typer.Option(
        False, help="Remove ingredient tags input from dataset (ablation)"
    ),
    remove_nutriments: bool = typer.Option(
        False, help="Remove nutriment inputs from dataset (ablation)"
    ),
    remove_product_name: bool = typer.Option(
        False, help="Remove product name inputs from dataset (ablation)"
    ),
    remove_image_embedding: bool = typer.Option(
        False, help="Remove image embedding inputs from dataset (ablation)"
    ),
):
    MODEL_BASE_DIR = model_dir.parent
    TRAIN_SPLIT = "train[:80%]"
    VAL_SPLIT = "train[80%:90%]"
    TEST_SPLIT = "train[90%:]"

    # Sanity check: the splits must match the reference barcode lists (stored
    # next to the model directory) and must not overlap.
    print("checking training splits...")
    split_barcodes = {}
    SPLIT_DIR = MODEL_BASE_DIR / "splits"
    missing_splits = not SPLIT_DIR.exists()
    for split_name, split_command in (
        ("train", TRAIN_SPLIT),
        ("val", VAL_SPLIT),
        ("test", TEST_SPLIT),
    ):
        print(f"checking split {split_name}")
        barcodes = extract_barcodes(load_dataset("off_categories", split=split_command))
        split_barcodes[split_name] = set(barcodes)
        if len(split_barcodes[split_name]) != len(barcodes):
            raise ValueError(f"duplicate products in {split_name} split")
        if missing_splits:
            # No reference files yet: write them so that future runs are
            # checked against the same split.
            SPLIT_DIR.mkdir(exist_ok=True)
            (SPLIT_DIR / f"{split_name}.txt").write_text("\n".join(barcodes))
        else:
            expected_barcodes = (
                (SPLIT_DIR / f"{split_name}.txt").read_text().splitlines()
            )
            if barcodes != expected_barcodes:
                raise ValueError(
                    f"barcodes for split {split_name} did not match reference"
                )

    for split_1, split_2 in (("train", "val"), ("train", "test"), ("val", "test")):
        if split_barcodes[split_1].intersection(split_barcodes[split_2]):
            raise ValueError(f"splits {split_1} and {split_2} intersect")

    print("Downloading and preparing dataset...")
    builder = tfds.builder("off_categories")
    builder.download_and_prepare()

    SAVED_MODEL_DIR = model_dir / "saved_model"
    m, labels = load_model(
        SAVED_MODEL_DIR,
        custom_objects={
            "PrecisionWithAverage": PrecisionWithAverage,
            "RecallWithAverage": RecallWithAverage,
        },
    )
    has_image_embedding = any(i.name == "image_embeddings" for i in m.inputs)

    for split_name, split_command in (("val", VAL_SPLIT), ("test", TEST_SPLIT)):
        split_ds = load_dataset("off_categories", split=split_command).apply(
            fix_image_embeddings_mask if has_image_embedding else lambda ds: ds
        )
        # Apply the requested ablations; the removed inputs are recorded in
        # the report file name.
        suffixes = []
        if remove_ingredient_ocr_tags:
            suffixes.append("ingredients_ocr_tags")
            split_ds = split_ds.map(remove_ingredients_ocr_func)
        if remove_ingredients_tags:
            suffixes.append("ingredients_tags")
            split_ds = split_ds.map(remove_ingredients_func)
        if remove_nutriments:
            suffixes.append("nutriments")
            split_ds = split_ds.map(remove_nutriments_func)
        if remove_product_name:
            suffixes.append("product_name")
            split_ds = split_ds.map(remove_product_name_func)
        if remove_image_embedding:
            suffixes.append("image_embeddings")
            split_ds = split_ds.map(remove_image_embeddings_func)

        # Predict and binarize the predicted probabilities with a 0.5 threshold.
        y_pred = m.predict(split_ds.padded_batch(32))
        y_pred_binary = np.zeros(y_pred.shape, dtype=int)
        y_pred_binary[y_pred >= 0.5] = 1
        y_true = generate_y_true(y_pred.shape, split_ds, labels)

        ablation_dir = model_dir / "ablations"
        ablation_dir.mkdir(exist_ok=True)
        suffixes.append(split_name)
        suffix = "_".join(suffixes)
        output_report_path = ablation_dir / f"classification_report_{suffix}.json"
        if not output_report_path.exists():
            print("generating classification report...")
            metrics = classification_report(
                y_true, y_pred_binary, target_names=labels, output_dict=True
            )
            # Sort labels by support (most frequent first) for readability.
            metrics = dict(
                sorted(metrics.items(), key=lambda x: x[1]["support"], reverse=True)
            )
            with output_report_path.open("w") as f:
                json.dump(metrics, f, indent=4)


if __name__ == "__main__":
    typer.run(main)
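# Example invocation (the model path is illustrative; typer derives the
# command-line flag names from the parameters of main above):
#   python run_ablation_study.py --model-dir experiments/my_model --remove-nutriments
# Without any --remove-* flag, reports are generated on the unmodified val/test inputs.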