From 86f0bd90e67772db1e664a63581f57a8eb7c7a9a Mon Sep 17 00:00:00 2001
From: danellecline
Date: Fri, 20 Dec 2024 09:45:56 -0800
Subject: [PATCH] perf: compute embedding in same block as model output to avoid forward compute twice

---
 biotrack/embedding.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/biotrack/embedding.py b/biotrack/embedding.py
index 9e1412f..c374bcf 100644
--- a/biotrack/embedding.py
+++ b/biotrack/embedding.py
@@ -16,7 +16,7 @@
 from biotrack.logger import info, err, debug
 from biotrack.track import Track
 
-from transformers import AutoModelForImageClassification, AutoImageProcessor
+from transformers import AutoConfig, AutoModelForImageClassification, AutoImageProcessor
 from torchvision import transforms
 import torch.nn.functional as F
 
@@ -33,7 +33,8 @@ class ViTWrapper:
     def __init__(self, batch_size: int = 32, model_name: str = DEFAULT_MODEL_NAME, device_id: int = 0):
         self.batch_size = batch_size
         self.name = model_name
-        self.model = AutoModelForImageClassification.from_pretrained(model_name)
+        config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
+        self.model = AutoModelForImageClassification.from_pretrained(model_name, config=config)
         self.processor = AutoImageProcessor.from_pretrained(model_name)
 
         if not Path(model_name).exists():
@@ -65,6 +66,8 @@ def process_images(self, image_paths: list) -> tuple:
         with torch.no_grad():
             outputs = self.model(inputs)
             logits = outputs.logits
+            embeddings = outputs.hidden_states[-1]
+            batch_embeddings = embeddings[:, 0, :].cpu().numpy()
 
         # Get the top 5 classes and scores
         top_scores, top_classes = torch.topk(logits, 5)
@@ -109,10 +112,6 @@ def process_images(self, image_paths: list) -> tuple:
             err(f"No keypoints found for {image_path}")
             continue
 
-        with torch.no_grad():
-            embeddings = self.model.base_model(inputs)
-            batch_embeddings = embeddings.last_hidden_state[:, 0, :].cpu().numpy()
-
         predicted_classes = [[self.model.config.id2label[class_idx] for class_idx in class_list] for class_list in top_classes]
         predicted_scores = [[score for score in score_list] for score_list in top_scores]
 
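
For reference, the pattern this patch adopts, a single forward pass that yields both the classification logits and the CLS-token embedding, can be exercised standalone roughly as in the sketch below. The model name "google/vit-base-patch16-224" and the random dummy images are illustrative stand-ins, not biotrack's DEFAULT_MODEL_NAME or its real crop preprocessing.

# Sketch of the single-forward-pass pattern from the patch (illustrative only).
import numpy as np
import torch
from PIL import Image
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification

model_name = "google/vit-base-patch16-224"  # stand-in for DEFAULT_MODEL_NAME

# Request hidden states so the classifier call also returns the backbone features.
config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
model = AutoModelForImageClassification.from_pretrained(model_name, config=config)
processor = AutoImageProcessor.from_pretrained(model_name)
model.eval()

# Two random RGB images stand in for real frame crops.
images = [Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)) for _ in range(2)]
inputs = processor(images=images, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits                                         # (batch, num_classes)
    embeddings = outputs.hidden_states[-1][:, 0, :].cpu().numpy()   # CLS token, no second pass

top_scores, top_classes = torch.topk(logits, 5)
print(embeddings.shape, top_classes.shape)

One caveat worth noting: outputs.hidden_states[-1] is not always byte-identical to base_model(...).last_hidden_state; in Hugging Face's ViT implementation the latter has a final layernorm applied after the last encoder layer, so the embeddings are equivalent only up to that normalization.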