Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Owlv2; Add option to use owlvit / owlv2's image preprocessing procedure (the roi_align in current implementation processes images differently and results in distribution shifts and subpar performance) #23

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
14 changes: 11 additions & 3 deletions examples/owl_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@

parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
parser.add_argument("--prompt", type=str, default="an owl, a glove")
parser.add_argument("--prompt", type=str, default="[an owl, a glove]")
parser.add_argument("--threshold", type=str, default="0.1,0.1")
parser.add_argument("--nms_threshold", type=float, default=0.3)
parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
parser.add_argument('--no_roi_align', action='store_true')
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num_profiling_runs", type=int, default=30)
Expand All @@ -45,13 +47,17 @@

thresholds = args.threshold.strip("][()")
thresholds = thresholds.split(',')
thresholds = [float(x) for x in thresholds]
if len(thresholds) == 1:
thresholds = float(thresholds[0])
else:
thresholds = [float(x) for x in thresholds]
print(thresholds)


predictor = OwlPredictor(
args.model,
image_encoder_engine=args.image_encoder_engine
image_encoder_engine=args.image_encoder_engine,
no_roi_align=args.no_roi_align
)

image = PIL.Image.open(args.image)
Expand All @@ -63,6 +69,7 @@
text=text,
text_encodings=text_encodings,
threshold=thresholds,
nms_threshold=args.nms_threshold,
pad_square=False
)

Expand All @@ -75,6 +82,7 @@
text=text,
text_encodings=text_encodings,
threshold=thresholds,
nms_threshold=args.nms_threshold,
pad_square=False
)
torch.cuda.current_stream().synchronize()
Expand Down
4 changes: 3 additions & 1 deletion examples/tree_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
parser.add_argument("--threshold", type=float, default=0.1)
parser.add_argument("--output", type=str, default="../data/tree_predict_out.jpg")
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
parser.add_argument('--no_roi_align', action='store_true')
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
args = parser.parse_args()

predictor = TreePredictor(
owl_predictor=OwlPredictor(
args.model,
image_encoder_engine=args.image_encoder_engine
image_encoder_engine=args.image_encoder_engine,
no_roi_align=args.no_roi_align
)
)

Expand Down
59 changes: 53 additions & 6 deletions nanoowl/image_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@


import torch
import torchvision
import PIL.Image
import numpy as np
from typing import Tuple
from typing import Tuple, Optional, Union


__all__ = [
Expand All @@ -44,7 +45,10 @@
class ImagePreprocessor(torch.nn.Module):
def __init__(self,
mean: Tuple[float, float, float] = DEFAULT_IMAGE_PREPROCESSOR_MEAN,
std: Tuple[float, float, float] = DEFAULT_IMAGE_PREPROCESSOR_STD
std: Tuple[float, float, float] = DEFAULT_IMAGE_PREPROCESSOR_STD,
resize: Optional[Union[int, Tuple[int, int]]] = None,
resize_by_pad: bool = False,
padding_value: Optional[float] = 127.5,
):
super().__init__()

Expand All @@ -57,8 +61,47 @@ def __init__(self,
torch.tensor(std)[None, :, None, None]
)

def forward(self, image: torch.Tensor, inplace: bool = False):
if resize is not None and isinstance(resize, int):
resize = (resize, resize)
self.resize = resize
self.resize_by_pad = resize_by_pad
self.padding_value = padding_value
if (resize is not None) and (not resize_by_pad):
self.resizer = torchvision.transforms.Resize(
resize,
interpolation=torchvision.transforms.InterpolationMode.BICUBIC
)
else:
self.resizer = None

def forward(self, image: torch.Tensor, inplace: bool = False):

if self.resize:
if self.resizer is not None:
image = self.resizer(image)
if self.resize_by_pad:
if image.size(-1) <= self.resize[-1] and image.size(-2) <= self.resize[-2]:
image = torch.nn.functional.pad(
image,
[0, self.resize[-1] - image.size(-1), 0, self.resize[-2] - image.size(-2)],
"constant",
self.padding_value
)
else:
downsample_factor = max(image.size(-2) / self.resize[-2], image.size(-1) / self.resize[-1])
target_size = (round(image.size(-2) / downsample_factor), round(image.size(-1) / downsample_factor))
image = torchvision.transforms.functional.resize(
image,
target_size,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
image = torch.nn.functional.pad(
image,
[0, self.resize[-1] - image.size(-1), 0, self.resize[-2] - image.size(-2)],
"constant",
self.padding_value
)

if inplace:
image = image.sub_(self.mean).div_(self.std)
else:
Expand All @@ -67,9 +110,13 @@ def forward(self, image: torch.Tensor, inplace: bool = False):
return image

@torch.no_grad()
def preprocess_pil_image(self, image: PIL.Image.Image):
image = torch.from_numpy(np.asarray(image))
def preprocess_numpy_array(self, image: np.ndarray):
image = torch.from_numpy(image)
image = image.permute(2, 0, 1)[None, ...]
image = image.to(self.mean.device)
image = image.type(self.mean.dtype)
return self.forward(image, inplace=True)
return self.forward(image, inplace=True)

@torch.no_grad()
def preprocess_pil_image(self, image: PIL.Image.Image):
return self.preprocess_numpy_array(np.asarray(image))
6 changes: 3 additions & 3 deletions nanoowl/owl_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_colors(count: int):
def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=True):
is_pil = not isinstance(image, np.ndarray)
if is_pil:
image = np.asarray(image)
image = np.array(image)
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.75
colors = get_colors(len(text))
Expand All @@ -58,7 +58,7 @@ def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=T
if draw_text:
offset_y = 12
offset_x = 0
label_text = text[label_index]
label_text = text[label_index] + ' ' + f'{output.scores[i]:.2f}'
cv2.putText(
image,
label_text,
Expand All @@ -71,4 +71,4 @@ def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=T
)
if is_pil:
image = PIL.Image.fromarray(image)
return image
return image
Loading