Feat/gridcontext #871

Open · wants to merge 2 commits into main
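Summary: adds a grid-context segmentation experiment (experiments/gridcontext.py), renames threshold keyword arguments in experiments/fastsamsom.py, and refactors openadapt/adapters/ultralytics.py to render FastSAM segmentations in memory instead of via a temporary file.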
4 changes: 2 additions & 2 deletions experiments/fastsamsom.py
@@ -27,9 +27,9 @@ def main() -> None:
     segmented_image = segmentation_adapter.fetch_segmented_image(
         image,
         # threshold below which boxes will be filtered out
-        conf=0,
+        min_confidence_threshold=0,
         # discards all overlapping boxes with IoU > iou_threshold
-        iou=0.05,
+        max_iou_threshold=0.05,
     )
     if DEBUG:
         segmented_image.show()
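For reference, a minimal sketch of the renamed keyword arguments (the asset path is illustrative; the adapter accessor and both parameter names appear elsewhere in this PR):

    from PIL import Image
    from openadapt import adapters

    segmentation_adapter = adapters.get_default_segmentation_adapter()
    image = Image.open("tests/assets/excel.png")
    segmented = segmentation_adapter.fetch_segmented_image(
        image,
        min_confidence_threshold=0,  # keep every box regardless of confidence
        max_iou_threshold=0.05,      # suppress almost all overlapping boxes
    )
    segmented.show()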
152 changes: 152 additions & 0 deletions experiments/gridcontext.py
@@ -0,0 +1,152 @@
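"""Experiment: probe how simple image augmentations (inversion, extreme
contrast, CLAHE) affect FastSAM segmentation of a spreadsheet screenshot."""
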
from PIL import Image, ImageEnhance, ImageOps, ImageFilter
import cv2
import numpy as np
from openadapt import adapters, config, utils

def apply_augmentations(image: Image.Image) -> list[Image.Image]:
    """Apply a series of augmentations to the image.

    Args:
        image (Image.Image): The original image.

    Returns:
        list[Image.Image]: List of augmented images.
    """
    augmented_images = []

    # Original image
    augmented_images.append(image)

    # Disabled augmentations, kept for reference:
    """
    # Increase contrast
    enhancer = ImageEnhance.Contrast(image)
    enhanced_image = enhancer.enhance(1000)

    # Increase sharpness
    sharpness_enhancer = ImageEnhance.Sharpness(image)
    augmented_images.append(sharpness_enhancer.enhance(100))  # Sharper image

    # Adaptive Histogram Equalization (CLAHE)
    clahe_image = apply_clahe(image)
    augmented_images.append(clahe_image)

    # Edge Enhancement
    edge_enhanced_image = image.filter(ImageFilter.EDGE_ENHANCE)
    augmented_images.append(edge_enhanced_image)
    """

    # Inverted image
    inverted_image = ImageOps.invert(image.convert("RGB"))
    augmented_images.append(inverted_image)

    # Extreme contrast on the inverted image drives it toward binary, which may
    # make UI boundaries easier for the segmentation model to detect.
    enhancer = ImageEnhance.Contrast(inverted_image)
    enhanced_image = enhancer.enhance(1000)
    augmented_images.append(enhanced_image)

    return augmented_images
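
# Illustrative usage of apply_augmentations (asset path taken from _main below):
#   variants = apply_augmentations(Image.open("tests/assets/excel.png"))
#   # -> [original, inverted, inverted + extreme contrast]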

def apply_clahe(image: Image.Image) -> Image.Image:
    """Apply CLAHE (adaptive histogram equalization) to the image.

    Args:
        image (Image.Image): The original image.

    Returns:
        Image.Image: Image after applying CLAHE.
    """
    # Convert to OpenCV format (BGR)
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Convert to grayscale
    gray_image = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)

    # Apply CLAHE: clipLimit bounds per-tile contrast amplification;
    # tileGridSize sets the number of local regions that are equalized.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized_image = clahe.apply(gray_image)

    # Convert back to RGB and PIL format
    equalized_image = cv2.cvtColor(equalized_image, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(equalized_image)

# NOTE: _main is not wired to __main__ below; invoke it manually to run the
# augmentation sweep.
def _main():
    image_file_path = config.ROOT_DIR_PATH / "../tests/assets/excel.png"
    image = Image.open(image_file_path)

    segmentation_adapter = adapters.get_default_segmentation_adapter()

    # Generate augmented variants of the original image
    images_to_segment = apply_augmentations(image)

    for _image in images_to_segment:
        _image.show()
        segmented_image = segmentation_adapter.fetch_segmented_image(
            _image,
            min_confidence_threshold=0,
            max_iou_threshold=0.05,
        )
        segmented_image.show()
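
# --- Second variant: segment the original image and a contrast-boosted copy,
# then group visually similar segments and prompt for descriptions. ---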

from openadapt import plotting, strategies, vision

def main():
    image_file_path = config.ROOT_DIR_PATH / "../tests/assets/excel.png"
    image = Image.open(image_file_path)
    image.show()

    # An extreme factor effectively binarizes the image, which may make
    # high-contrast UI edges easier to segment.
    CONTRAST_FACTOR = 1000
    image_contrasted = utils.increase_contrast(image, CONTRAST_FACTOR)

    segmentation_adapter = adapters.get_default_segmentation_adapter()
    all_annotations = []

    # Toggle between rendering segmented images and collecting raw annotations.
    USE_SEGMENTED_IMAGE = True
    for _image in (image, image_contrasted):
        if USE_SEGMENTED_IMAGE:
            segmented_image = segmentation_adapter.fetch_segmented_image(
                _image,
                # threshold below which boxes will be filtered out
                min_confidence_threshold=0,
                # discards all overlapping boxes with IoU > iou_threshold
                max_iou_threshold=0.05,
            )
            segmented_image.show()
        else:
            annotations = segmentation_adapter.get_annotations(
                _image,
                # threshold below which boxes will be filtered out
                min_confidence_threshold=0,
                # discards all overlapping boxes with IoU > iou_threshold
                max_iou_threshold=0.1,
            )
            all_annotations += annotations

    # Masks can also be pulled directly from the collected annotations.
    annotation_masks = [annotation.masks.numpy() for annotation in all_annotations]

    masks = vision.get_masks_from_segmented_image(segmented_image)
    plotting.display_binary_images_grid(masks)

    refined_masks = vision.refine_masks(masks)
    plotting.display_binary_images_grid(refined_masks)

    masked_images = vision.extract_masked_images(image, refined_masks)

    # NOTE: these thresholds are not defined in this file; the values below are
    # illustrative placeholders.
    MIN_SEGMENT_SSIM = 0.95
    MIN_SEGMENT_SIZE_SIM = 0.95
    similar_idx_groups, ungrouped_idxs, _, _ = vision.get_similar_image_idxs(
        masked_images,
        MIN_SEGMENT_SSIM,
        MIN_SEGMENT_SIZE_SIM,
    )
    ungrouped_masked_images = [
        masked_images[idx]
        for idx in ungrouped_idxs
    ]
    ungrouped_descriptions = strategies.visual.prompt_for_descriptions(
        image,
        ungrouped_masked_images,
        None,
    )


if __name__ == "__main__":
    main()
126 changes: 66 additions & 60 deletions openadapt/adapters/ultralytics.py
@@ -26,7 +26,7 @@


 from ultralytics import FastSAM
-from ultralytics.models.fastsam import FastSAMPrompt
+from ultralytics.models.fastsam import FastSAMPredictor
 from ultralytics.models.sam import Predictor as SAMPredictor
 import fire
 import numpy as np
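(Context: recent ultralytics releases appear to fold the FastSAM prompt workflow into FastSAMPredictor, which seems to motivate this import swap; the ultralytics issue linked in the removed code below discusses the related breakage.)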
@@ -70,10 +70,12 @@ def fetch_segmented_image(
     return do_sam(image, model_name, **kwargs)


+# TODO: support SAM models
+# TODO: consolidate with do_fastsam
 @cache.cache()
-def do_fastsam(
+def get_annotations(
     image: Image,
-    model_name: str,
+    model_name: str = FASTSAM_MODEL_NAMES[0],
     # TODO: inject from config
     device: str = "cpu",
     retina_masks: bool = True,
@@ -82,10 +84,12 @@ def do_fastsam(
     min_confidence_threshold: float = 0.4,
     # discards all overlapping boxes with IoU > iou_threshold
     max_iou_threshold: float = 0.9,
+    # The maximum number of boxes to keep after NMS.
+    max_det: int = 1000,
     max_retries: int = 5,
     retry_delay_seconds: float = 0.1,
 ) -> Image:
-    """Get segmented image via FastSAM.
+    """Get mask segments via FastSAM.

     For usage of thresholds see:
     github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
@@ -105,77 +109,79 @@

     imgsz = imgsz or image.size

     # Run inference on image
     everything_results = model(
         image,
         device=device,
         retina_masks=retina_masks,
         imgsz=imgsz,
         conf=min_confidence_threshold,
         iou=max_iou_threshold,
+        max_det=max_det,
     )
     assert len(everything_results) == 1, len(everything_results)
+    return everything_results

-    # Prepare a Prompt Process object
-    prompt_process = FastSAMPrompt(image, everything_results, device="cpu")
-
-    # Everything prompt
-    annotations = prompt_process.everything_prompt()
-
-    # TODO: support other modes once issues are fixed
-    # https://github.com/ultralytics/ultralytics/issues/13218#issuecomment-2142960103
-
-    # Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
-    # annotations = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
-
-    # Text prompt
-    # annotations = prompt_process.text_prompt(text='a photo of a dog')
+@cache.cache()
+def do_fastsam(
+    image: Image,
+    model_name: str,
+    # TODO: inject from config
+    device: str = "cpu",
+    retina_masks: bool = True,
+    imgsz: int | tuple[int, int] | None = 1024,
+    # threshold below which boxes will be filtered out
+    min_confidence_threshold: float = 0.4,
+    # discards all overlapping boxes with IoU > iou_threshold
+    max_iou_threshold: float = 0.9,
+    # The maximum number of boxes to keep after NMS.
+    max_det: int = 1000,
+    max_retries: int = 5,
+    retry_delay_seconds: float = 0.1,
+) -> Image:
+    """Get segmented image via FastSAM.

-    # Point prompt
-    # points default [[0,0]] [[x1,y1],[x2,y2]]
-    # point_label default [0] [1,0] 0:background, 1:foreground
-    # annotations = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
+    For usage of thresholds see:
+    github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
+    /ultralytics/utils/ops.py#L164

-    assert len(annotations) == 1, len(annotations)
-    annotation = annotations[0]
+    Args:
+        TODO
+        min_confidence_threshold (float, optional): The minimum confidence score
+            that a detection must meet or exceed to be considered valid.
+            Detections below this threshold will not be marked. Defaults to 0.00.
+        max_iou_threshold (float, optional): The maximum allowed Intersection over
+            Union (IoU) value for overlapping detections. Detections that exceed
+            this IoU threshold are considered for suppression, keeping only the
+            detection with the highest confidence. Defaults to 0.05.
+    """
+    model = FastSAM(model_name)

-    # hide original image
-    annotation.orig_img = np.ones(annotation.orig_img.shape)
+    imgsz = imgsz or image.size

-    # TODO: in memory, e.g. with prompt_process.fast_show_mask()
-    with TemporaryDirectory() as tmp_dir:
-        # Force the output format to PNG to prevent JPEG compression artefacts
-        annotation.path = annotation.path.replace(".jpg", ".png")
-        prompt_process.plot(
-            [annotation],
-            tmp_dir,
-            with_contours=False,
-            retina=False,
-        )
+    everything_results = model(
+        image,
+        device=device,
+        retina_masks=retina_masks,
+        imgsz=imgsz,
+        conf=min_confidence_threshold,
+        iou=max_iou_threshold,
+        max_det=max_det,
+    )
+    assert len(everything_results) == 1, len(everything_results)
+    annotation = everything_results[0]

+    segmented_image = Image.fromarray(
+        annotation.plot(
+            img=np.ones(annotation.orig_img.shape, dtype=annotation.orig_img.dtype),
+            kpt_line=False,
+            labels=False,
+            boxes=False,
+            probs=False,
+            color_mode='instance',
+        )
-        result_name = os.path.basename(annotation.path)
-        logger.info(f"{annotation.path=}")
-        segmented_image_path = Path(tmp_dir) / result_name
-        segmented_image = Image.open(segmented_image_path)

-        # Ensure the image is fully loaded before deletion to avoid errors or incomplete operations,
-        # as some operating systems and file systems lock files during read or processing.
-        segmented_image.load()

-        # Attempt to delete the file with retries and delay
-        retries = 0

-        while retries < max_retries:
-            try:
-                os.remove(segmented_image_path)
-                break  # If deletion succeeds, exit loop
-            except OSError as e:
-                if e.errno == errno.ENOENT:  # File not found
-                    break
-                else:
-                    retries += 1
-                    time.sleep(retry_delay_seconds)

-        if retries == max_retries:
-            logger.warning(f"Failed to delete {segmented_image_path}")
+    )

     # Check if the dimensions of the original and segmented images differ
     # XXX TODO this is a hack, this plotting code should be refactored, but the
     # bug may exist in ultralytics, since they seem to resize as well; see:
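As a usage sketch of the refactored adapter (the model file name and call sites are assumptions, not part of the diff):

    from PIL import Image
    from openadapt.adapters import ultralytics

    image = Image.open("tests/assets/excel.png")

    # New helper: returns the raw (cached) ultralytics Results list.
    results = ultralytics.get_annotations(
        image,
        min_confidence_threshold=0,
        max_iou_threshold=0.05,
    )

    # Refactored path: renders the segmentation in memory via Results.plot()
    # instead of writing to and deleting a temporary file.
    segmented = ultralytics.do_fastsam(
        image,
        "FastSAM-x.pt",  # assumed model name
        min_confidence_threshold=0,
        max_iou_threshold=0.05,
    )
    segmented.show()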