Skip to content

Commit

Permalink
perf: compute embedding in same block as model output to avoid forward compute twice
Browse files Browse the repository at this point in the history
  • Loading branch information
danellecline committed Dec 20, 2024
1 parent fab4296 commit 86f0bd9
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions biotrack/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from biotrack.logger import info, err, debug
from biotrack.track import Track

-from transformers import AutoModelForImageClassification, AutoImageProcessor
+from transformers import AutoConfig, AutoModelForImageClassification, AutoImageProcessor
from torchvision import transforms
import torch.nn.functional as F

Expand All @@ -33,7 +33,8 @@ class ViTWrapper:
def __init__(self, batch_size: int = 32, model_name: str = DEFAULT_MODEL_NAME, device_id: int = 0):
self.batch_size = batch_size
self.name = model_name
-self.model = AutoModelForImageClassification.from_pretrained(model_name)
+config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
+self.model = AutoModelForImageClassification.from_pretrained(model_name, config=config)
self.processor = AutoImageProcessor.from_pretrained(model_name)

if not Path(model_name).exists():
Expand Down Expand Up @@ -65,6 +66,8 @@ def process_images(self, image_paths: list) -> tuple:
with torch.no_grad():
outputs = self.model(inputs)
logits = outputs.logits
+embeddings = outputs.hidden_states[-1]
+batch_embeddings = embeddings[:, 0, :].cpu().numpy()

# Get the top 5 classes and scores
top_scores, top_classes = torch.topk(logits, 5)
Expand Down Expand Up @@ -109,10 +112,6 @@ def process_images(self, image_paths: list) -> tuple:
err(f"No keypoints found for {image_path}")
continue

-with torch.no_grad():
-embeddings = self.model.base_model(inputs)
-batch_embeddings = embeddings.last_hidden_state[:, 0, :].cpu().numpy()

predicted_classes = [[self.model.config.id2label[class_idx] for class_idx in class_list] for class_list in top_classes]
predicted_scores = [[score for score in score_list] for score_list in top_scores]

Expand Down

0 comments on commit 86f0bd9

Please sign in to comment.