Skip to content

Commit

Permalink
feat: complete usable model
Browse files Browse the repository at this point in the history
  • Loading branch information
matteopolak committed Jul 2, 2024
1 parent 536f4af commit 8399c29
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 36 deletions.
3 changes: 1 addition & 2 deletions grill/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

with open("data/classes.pkl", "rb") as f:
Expand All @@ -26,7 +25,7 @@
class RecipeImageDataset(Dataset):
def __init__(self, parquet_file, img_dir, transform=None, partition="train"):
self.annotations = (pl.read_parquet(parquet_file)
.filter("parititon" == partition)
.filter(pl.col("partition") == partition)
.with_row_index())

self.img_dir = img_dir
Expand Down
6 changes: 2 additions & 4 deletions grill/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def format_prediction(prediction: torch.Tensor) -> str:
predicted_classes = []

for i, v in enumerate(prediction):
if v > 0.7:
predicted_classes.append(classes[i])
if v > 0.96:
predicted_classes.append(f"{classes[i]} ({v:.2f})")

return ', '.join(predicted_classes)

Expand All @@ -53,5 +53,3 @@ def format_prediction(prediction: torch.Tensor) -> str:
plt.imshow(image)
plt.show()

break

13 changes: 0 additions & 13 deletions grill/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,5 @@

torch.save(model.state_dict(), f"checkpoints/grill-epoch{epoch}.pth")

accuracy = 0.0
total = 0

with torch.no_grad():
for images, labels in dataloader:
outputs = model(images)
predicted = outputs.sigmoid()

total += labels.size(0)
accuracy += (predicted - labels).abs().sum().item()

logger.info(f"Epoch {epoch}, Accuracy: {accuracy/total}")

torch.save(model.state_dict(), "models/grill.pth")

50 changes: 33 additions & 17 deletions prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,47 @@
.list.eval(pl.element()
.struct.field("text")
.str.to_lowercase()
.str.strip_chars()
# Remove any text between parentheses. If the ingredient starts with a number,
# also strip the leading number(s) and the single word that follows them:
# consecutive numbers are all removed, followed by the next word.
# e.g. "1 1/2 cup sugar" -> "sugar"
.str.replace_all(r"\([^()]*\)|,.*|\"|\'|^\w\b", "")
.str.replace_all(
r"\([^()]*\)|,.*|\"|\'",
""
)
.str.replace_all(
r"^\s*[-\d][^\s]*\s+(?:[-\d][^\s]+\s+)*[^\s]*\s*([^\-\d\s])",
r"^\s*[-\d/][^\s]*\s+(?:[-\d/][^\s]+\s+)*[^\s]*\s*([^\-\d\s/])",
"${1}"
))
)
.str.replace_all(r"\.", "")
.str.replace_all(r"\bnull\b", "")
.str.strip_chars()
.str.strip_prefix("cup ")
.str.strip_prefix("cups ")
.str.strip_prefix("tbsp ")
.str.strip_prefix("tablespoon ")
.str.strip_prefix("tablespoons ")
.str.strip_prefix("tsp ")
.str.strip_prefix("lbs ")
.str.strip_prefix("lb ")
.str.strip_prefix("dash ")
.str.strip_chars()
.str.strip_prefix("of ")
.str.replace_all(r"\s{2,}", " ")
.str.strip_chars())
.list.eval(pl.element()
.filter(pl.element().str.len_bytes().is_between(2, 15) & ~pl.element().str.ends_with("ed")))
.alias("ingredients")).to_series()
)
.list.unique()
.alias("ingredients")).to_series())

classes = (df.get_column("ingredients")
.explode().alias("ingredient")
.value_counts(sort=True)
.filter(pl.col("count") >= 100))
.value_counts()
.filter(pl.col("count") >= 80))

# remove ingredients that are not in the classes
df.replace_column(
df.get_column_index("ingredients"),
df.select(pl.col("ingredients")
.list.eval(pl.element().filter(pl.element().is_in(classes.get_column("ingredient"))))).to_series())

df.filter(pl.col("ingredients").list.len() > 5)

num_recipes = len(df)
df = df.filter(pl.col("ingredients").list.len() > 6)

df.write_parquet("data/annotations.parquet")

Expand All @@ -64,11 +74,17 @@
# so it can be used as a weight in the loss function
classes = (df.get_column("ingredients")
.explode().alias("ingredient")
.value_counts()
.filter(pl.col("count") > 100))
.value_counts(sort=True)
.filter(pl.col("count") > 80)
.drop_nulls())

n_recipes = len(df)
classes = classes.replace_column(
classes.get_column_index("count"),
classes.select(pl.lit(num_recipes) / pl.col("count")).to_series())
classes.select((pl.lit(n_recipes) / pl.col("count")).alias("count")).to_series())

classes = classes.to_dict()
classes = dict(zip(classes["ingredient"], classes["count"]))

with open("data/classes.pkl", "wb") as f:
pickle.dump(classes, f)
Expand Down

0 comments on commit 8399c29

Please sign in to comment.