forked from giswqs/GroundingDINO
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature/first_batch_of_model_usability_upgrades (IDEA-Research#9)
* initial commit * test updated requirements.txt * move more code to inference utils * PIL import fix * add annotations utilities * README.md updates
- Loading branch information
Showing
4 changed files
with
131 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# IDE | ||
.idea/ | ||
.vscode/ | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from typing import Tuple, List | ||
|
||
import cv2 | ||
import numpy as np | ||
import supervision as sv | ||
import torch | ||
from PIL import Image | ||
from torchvision.ops import box_convert | ||
|
||
import groundingdino.datasets.transforms as T | ||
from groundingdino.models import build_model | ||
from groundingdino.util.misc import clean_state_dict | ||
from groundingdino.util.slconfig import SLConfig | ||
from groundingdino.util.utils import get_phrases_from_posmap | ||
|
||
|
||
def preprocess_caption(caption: str) -> str: | ||
result = caption.lower().strip() | ||
if result.endswith("."): | ||
return result | ||
return result + "." | ||
|
||
|
||
def load_model(model_config_path: str, model_checkpoint_path: str): | ||
args = SLConfig.fromfile(model_config_path) | ||
args.device = "cuda" | ||
model = build_model(args) | ||
checkpoint = torch.load(model_checkpoint_path, map_location="cpu") | ||
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) | ||
model.eval() | ||
return model | ||
|
||
|
||
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]: | ||
transform = T.Compose( | ||
[ | ||
T.RandomResize([800], max_size=1333), | ||
T.ToTensor(), | ||
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | ||
] | ||
) | ||
image_source = Image.open(image_path).convert("RGB") | ||
image = np.asarray(image_source) | ||
image_transformed, _ = transform(image_source, None) | ||
return image, image_transformed | ||
|
||
|
||
def predict( | ||
model, | ||
image: torch.Tensor, | ||
caption: str, | ||
box_threshold: float, | ||
text_threshold: float | ||
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: | ||
caption = preprocess_caption(caption=caption) | ||
|
||
model = model.cuda() | ||
image = image.cuda() | ||
|
||
with torch.no_grad(): | ||
outputs = model(image[None], captions=[caption]) | ||
|
||
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256) | ||
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4) | ||
|
||
mask = prediction_logits.max(dim=1)[0] > box_threshold | ||
logits = prediction_logits[mask] # logits.shape = (n, 256) | ||
boxes = prediction_boxes[mask] # boxes.shape = (n, 4) | ||
|
||
tokenizer = model.tokenizer | ||
tokenized = tokenizer(caption) | ||
|
||
phrases = [ | ||
get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '') | ||
for logit | ||
in logits | ||
] | ||
|
||
return boxes, logits.max(dim=1)[0], phrases | ||
|
||
|
||
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray: | ||
h, w, _ = image_source.shape | ||
boxes = boxes * torch.Tensor([w, h, w, h]) | ||
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() | ||
detections = sv.Detections(xyxy=xyxy) | ||
|
||
labels = [ | ||
f"{phrase} {logit:.2f}" | ||
for phrase, logit | ||
in zip(phrases, logits) | ||
] | ||
|
||
box_annotator = sv.BoxAnnotator() | ||
annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR) | ||
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels) | ||
return annotated_frame |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,9 @@ | ||
transformers==4.5.1 | ||
torch | ||
torchvision | ||
transformers | ||
addict | ||
yapf | ||
timm | ||
numpy | ||
opencv-python | ||
supervision==0.3.2 |