diff --git a/experiments/fastsamsom.py b/experiments/fastsamsom.py
index 10e11c69f..b679963ce 100644
--- a/experiments/fastsamsom.py
+++ b/experiments/fastsamsom.py
@@ -27,9 +27,9 @@ def main() -> None:
     segmented_image = segmentation_adapter.fetch_segmented_image(
         image,
         # threshold below which boxes will be filtered out
-        conf=0,
+        min_confidence_threshold=0,
         # discards all overlapping boxes with IoU > iou_threshold
-        iou=0.05,
+        max_iou_threshold=0.05,
     )
     if DEBUG:
         segmented_image.show()
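For reference, the IoU referred to in the comments above is the standard intersection-over-union of two boxes. A minimal, self-contained sketch (not part of the patch; the helper and the example boxes are purely illustrative) of why max_iou_threshold=0.05 deduplicates aggressively:

def iou(a: tuple[float, float, float, float], b: tuple[float, float, float, float]) -> float:
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union

# Two detections whose overlap exceeds 5% of their union count as duplicates
# at max_iou_threshold=0.05; only the higher-confidence one survives NMS.
print(iou((0, 0, 100, 100), (90, 0, 190, 100)))  # ~0.053 -> suppressed
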
+ """ + # Convert to OpenCV format (BGR) + image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + + # Convert to grayscale + gray_image = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY) + + # Apply CLAHE + clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) + equalized_image = clahe.apply(gray_image) + + # Convert back to RGB and PIL format + equalized_image = cv2.cvtColor(equalized_image, cv2.COLOR_GRAY2RGB) + return Image.fromarray(equalized_image) + +def _main(): + image_file_path = config.ROOT_DIR_PATH / "../tests/assets/excel.png" + image = Image.open(image_file_path) + + segmentation_adapter = adapters.get_default_segmentation_adapter() + all_annotations = [] + + # Apply augmentations to original and contrasted images + images_to_segment = apply_augmentations(image) + + for _image in images_to_segment: + _image.show() + segmented_image = segmentation_adapter.fetch_segmented_image( + _image, + min_confidence_threshold=0, + max_iou_threshold=0.05, + ) + segmented_image.show() + +from PIL import Image + +from openadapt import adapters, config, plotting, vision, strategies, utils + +def main(): + image_file_path = config.ROOT_DIR_PATH / "../tests/assets/excel.png" + image = Image.open(image_file_path) + image.show() + CONTRAST_FACTOR = 1000 + image_contrasted = utils.increase_contrast(image, CONTRAST_FACTOR) + + segmentation_adapter = adapters.get_default_segmentation_adapter() + all_annotations = [] + for _image in (image, image_contrasted): + if 1: + segmented_image = segmentation_adapter.fetch_segmented_image( + _image, + # threshold below which boxes will be filtered out + min_confidence_threshold=0, + # discards all overlapping boxes with IoU > iou_threshold + max_iou_threshold=0.05, + ) + segmented_image.show() + else: + annotations = segmentation_adapter.get_annotations( + _image, + # threshold below which boxes will be filtered out + min_confidence_threshold=0, + # discards all overlapping boxes with IoU > iou_threshold + max_iou_threshold=0.1, + ) + all_annotations += annotations + import ipdb; ipdb.set_trace() + masks = [annotation.masks.numpy() for annotation in all_annotations] + import ipdb; ipdb.set_trace() + + masks = vision.get_masks_from_segmented_image(segmented_image) + plotting.display_binary_images_grid(masks) + + refined_masks = vision.refine_masks(masks) + plotting.display_binary_images_grid(refined_masks) + + masked_images = vision.extract_masked_images(image, refined_masks) + + similar_idx_groups, ungrouped_idxs, _, _ = vision.get_similar_image_idxs( + masked_images, + MIN_SEGMENT_SSIM, + MIN_SEGMENT_SIZE_SIM, + ) + ungrouped_masked_images = [ + masked_images[idx] + for idx in ungrouped_idxs + ] + ungrouped_descriptions = strategies.visual.prompt_for_descriptions( + image, + ungrouped_masked_images, + None, + ) + +if __name__ == "__main__": + main() diff --git a/openadapt/adapters/ultralytics.py b/openadapt/adapters/ultralytics.py index d6f236421..9be6f4585 100644 --- a/openadapt/adapters/ultralytics.py +++ b/openadapt/adapters/ultralytics.py @@ -26,7 +26,7 @@ from ultralytics import FastSAM -from ultralytics.models.fastsam import FastSAMPrompt +from ultralytics.models.fastsam import FastSAMPredictor from ultralytics.models.sam import Predictor as SAMPredictor import fire import numpy as np @@ -70,10 +70,12 @@ def fetch_segmented_image( return do_sam(image, model_name, **kwargs) +# TODO: support SAM models +# TODO: consolidate with do_fastsam @cache.cache() -def do_fastsam( +def get_annotations( image: Image, - model_name: str, + model_name: str = 
diff --git a/openadapt/adapters/ultralytics.py b/openadapt/adapters/ultralytics.py
index d6f236421..9be6f4585 100644
--- a/openadapt/adapters/ultralytics.py
+++ b/openadapt/adapters/ultralytics.py
@@ -26,7 +26,7 @@
 
 from ultralytics import FastSAM
-from ultralytics.models.fastsam import FastSAMPrompt
+from ultralytics.models.fastsam import FastSAMPredictor
 from ultralytics.models.sam import Predictor as SAMPredictor
 import fire
 import numpy as np
@@ -70,10 +70,12 @@ def fetch_segmented_image(
     return do_sam(image, model_name, **kwargs)
 
 
+# TODO: support SAM models
+# TODO: consolidate with do_fastsam
 @cache.cache()
-def do_fastsam(
+def get_annotations(
     image: Image,
-    model_name: str,
+    model_name: str = FASTSAM_MODEL_NAMES[0],
     # TODO: inject from config
     device: str = "cpu",
     retina_masks: bool = True,
@@ -82,10 +84,12 @@
     min_confidence_threshold: float = 0.4,
     # discards all overlapping boxes with IoU > iou_threshold
     max_iou_threshold: float = 0.9,
+    # The maximum number of boxes to keep after NMS.
+    max_det: int = 1000,
     max_retries: int = 5,
     retry_delay_seconds: float = 0.1,
 ) -> Image:
-    """Get segmented image via FastSAM.
+    """Get mask segments via FastSAM.
 
     For usage of thresholds see:
     github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
@@ -105,7 +109,6 @@
 
     imgsz = imgsz or image.size
 
-    # Run inference on image
     everything_results = model(
         image,
         device=device,
@@ -113,69 +116,72 @@
         imgsz=imgsz,
         conf=min_confidence_threshold,
         iou=max_iou_threshold,
+        max_det=max_det,
     )
+    assert len(everything_results) == 1, len(everything_results)
+    return everything_results
 
-    # Prepare a Prompt Process object
-    prompt_process = FastSAMPrompt(image, everything_results, device="cpu")
-
-    # Everything prompt
-    annotations = prompt_process.everything_prompt()
-
-    # TODO: support other modes once issues are fixed
-    # https://github.com/ultralytics/ultralytics/issues/13218#issuecomment-2142960103
-
-    # Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
-    # annotations = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
-    # Text prompt
-    # annotations = prompt_process.text_prompt(text='a photo of a dog')
+
+@cache.cache()
+def do_fastsam(
+    image: Image,
+    model_name: str,
+    # TODO: inject from config
+    device: str = "cpu",
+    retina_masks: bool = True,
+    imgsz: int | tuple[int, int] | None = 1024,
+    # threshold below which boxes will be filtered out
+    min_confidence_threshold: float = 0.4,
+    # discards all overlapping boxes with IoU > iou_threshold
+    max_iou_threshold: float = 0.9,
+    # The maximum number of boxes to keep after NMS.
+    max_det: int = 1000,
+    max_retries: int = 5,
+    retry_delay_seconds: float = 0.1,
+) -> Image:
+    """Get segmented image via FastSAM.
 
-    # Point prompt
-    # points default [[0,0]] [[x1,y1],[x2,y2]]
-    # point_label default [0] [1,0] 0:background, 1:foreground
-    # annotations = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
+    For usage of thresholds see:
+    github.com/ultralytics/ultralytics/blob/dacbd48fcf8407098166c6812eeb751deaac0faf
+    /ultralytics/utils/ops.py#L164
 
-    assert len(annotations) == 1, len(annotations)
-    annotation = annotations[0]
+    Args:
+        TODO
+        min_confidence_threshold (float, optional): The minimum confidence score
+            that a detection must meet or exceed to be considered valid. Detections
+            below this threshold will not be marked. Defaults to 0.00.
+        max_iou_threshold (float, optional): The maximum allowed Intersection over
+            Union (IoU) value for overlapping detections. Detections that exceed this
+            IoU threshold are considered for suppression, keeping only the
+            detection with the highest confidence. Defaults to 0.05.
+    """
+    model = FastSAM(model_name)
 
-    # hide original image
-    annotation.orig_img = np.ones(annotation.orig_img.shape)
+    imgsz = imgsz or image.size
 
-    # TODO: in memory, e.g. with prompt_process.fast_show_mask()
-    with TemporaryDirectory() as tmp_dir:
-        # Force the output format to PNG to prevent JPEG compression artefacts
-        annotation.path = annotation.path.replace(".jpg", ".png")
-        prompt_process.plot(
-            [annotation],
-            tmp_dir,
-            with_contours=False,
-            retina=False,
+    everything_results = model(
+        image,
+        device=device,
+        retina_masks=retina_masks,
+        imgsz=imgsz,
+        conf=min_confidence_threshold,
+        iou=max_iou_threshold,
+        max_det=max_det,
+    )
+    assert len(everything_results) == 1, len(everything_results)
+    annotation = everything_results[0]
+
+    segmented_image = Image.fromarray(
+        annotation.plot(
+            img=np.ones(annotation.orig_img.shape, dtype=annotation.orig_img.dtype),
+            kpt_line=False,
+            labels=False,
+            boxes=False,
+            probs=False,
+            color_mode='instance',
         )
-        result_name = os.path.basename(annotation.path)
-        logger.info(f"{annotation.path=}")
-        segmented_image_path = Path(tmp_dir) / result_name
-        segmented_image = Image.open(segmented_image_path)
-
-        # Ensure the image is fully loaded before deletion to avoid errors or incomplete operations,
-        # as some operating systems and file systems lock files during read or processing.
-        segmented_image.load()
-
-        # Attempt to delete the file with retries and delay
-        retries = 0
-
-        while retries < max_retries:
-            try:
-                os.remove(segmented_image_path)
-                break  # If deletion succeeds, exit loop
-            except OSError as e:
-                if e.errno == errno.ENOENT:  # File not found
-                    break
-                else:
-                    retries += 1
-                    time.sleep(retry_delay_seconds)
-
-        if retries == max_retries:
-            logger.warning(f"Failed to delete {segmented_image_path}")
+    )
+
     # Check if the dimensions of the original and segmented images differ
     # XXX TODO this is a hack, this plotting code should be refactored, but the
     # bug may exist in ultralytics, since they seem to resize as well; see:
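The net effect of the ultralytics.py changes: do_fastsam no longer round-trips through FastSAMPrompt and a temporary PNG on disk, it renders the masks in memory straight from the Results object, and get_annotations exposes the raw results for callers that want the masks themselves. A minimal sketch of the in-memory path outside OpenAdapt (model name, input path, and thresholds are illustrative; the plot() keywords mirror the patched call):

import numpy as np
from PIL import Image
from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")  # illustrative model name
results = model(
    "tests/assets/excel.png",  # illustrative input
    retina_masks=True,
    imgsz=1024,
    conf=0.4,
    iou=0.9,
    max_det=1000,
)
annotation = results[0]

# Draw the masks onto a blank canvas instead of the original screenshot,
# mirroring the patched do_fastsam.
blank = np.ones(annotation.orig_img.shape, dtype=annotation.orig_img.dtype)
segmented = Image.fromarray(
    annotation.plot(img=blank, labels=False, boxes=False, probs=False)
)
segmented.show()

Callers that need binary masks rather than a rendered image can instead take the Results returned by get_annotations() and read them from results[0].masks, which is what gridcontext.py does with annotation.masks.numpy().
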
diff --git a/poetry.lock b/poetry.lock
index ac5e3e77f..d5e2ad833 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -7188,19 +7188,6 @@ mxnet = ["mxnet (>=1.5.1,<1.6.0)"]
 tensorflow = ["tensorflow (>=2.0.0,<2.6.0)"]
 torch = ["torch (>=1.6.0)"]
 
-[[package]]
-name = "thop"
-version = "0.1.1.post2209072238"
-description = "A tool to count the FLOPs of PyTorch model."
-optional = false
-python-versions = "*"
-files = [
-    {file = "thop-0.1.1.post2209072238-py3-none-any.whl", hash = "sha256:01473c225231927d2ad718351f78ebf7cffe6af3bed464c4f1ba1ef0f7cdda27"},
-]
-
-[package.dependencies]
-torch = "*"
-
 [[package]]
 name = "threadpoolctl"
 version = "3.4.0"
@@ -7674,17 +7661,18 @@ files = [
 
 [[package]]
 name = "ultralytics"
-version = "8.1.47"
+version = "8.2.79"
 description = "Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "ultralytics-8.1.47-py3-none-any.whl", hash = "sha256:0c3c2fba4b6758f037c48ac812b8239276c7d9d2863fd5328c070499aedc1fee"},
-    {file = "ultralytics-8.1.47.tar.gz", hash = "sha256:273402e2de47e2b18ff8bde19a07cd47e19f1894dc0faa340cfdb50eb0a69ed7"},
+    {file = "ultralytics-8.2.79-py3-none-any.whl", hash = "sha256:ae506cc8d1b473d6eb4be04a6080a3ab0a159b47ee76c794a16e6561f8d9afb5"},
+    {file = "ultralytics-8.2.79.tar.gz", hash = "sha256:3dda7a8ed246c11019214dc4eb0b3afbb22f0b6eaf1b69d48703e4865e7acd9e"},
 ]
 
 [package.dependencies]
 matplotlib = ">=3.3.0"
+numpy = ">=1.23.0,<2.0.0"
 opencv-python = ">=4.6.0"
 pandas = ">=1.1.4"
 pillow = ">=7.1.2"
@@ -7694,18 +7682,33 @@ pyyaml = ">=5.3.1"
 requests = ">=2.23.0"
 scipy = ">=1.4.1"
 seaborn = ">=0.11.0"
-thop = ">=0.1.1"
 torch = ">=1.8.0"
 torchvision = ">=0.9.0"
 tqdm = ">=4.64.0"
+ultralytics-thop = ">=2.0.0"
 
 [package.extras]
-dev = ["check-manifest", "coverage[toml]", "ipython", "mkdocs-jupyter", "mkdocs-material (>=9.5.9)", "mkdocs-redirects", "mkdocs-ultralytics-plugin (>=0.0.44)", "mkdocstrings[python]", "pre-commit", "pytest", "pytest-cov"]
+dev = ["coverage[toml]", "ipython", "mkdocs (>=1.6.0)", "mkdocs-jupyter", "mkdocs-macros-plugin (>=1.0.5)", "mkdocs-material (>=9.5.9)", "mkdocs-redirects", "mkdocs-ultralytics-plugin (>=0.1.2)", "mkdocstrings[python]", "pytest", "pytest-cov"]
 explorer = ["duckdb (<=0.9.2)", "lancedb", "streamlit"]
-export = ["coremltools (>=7.0)", "h5py (!=3.11.0)", "numpy (==1.23.5)", "onnx (>=1.12.0)", "openvino (>=2024.0.0)", "tensorflow (<=2.13.1)", "tensorflowjs (>=3.9.0)"]
-extra = ["albumentations (>=1.0.3)", "hub-sdk (>=0.0.5)", "ipython", "pycocotools (>=2.0.7)"]
+export = ["coremltools (>=7.0)", "flatbuffers (>=23.5.26,<100)", "h5py (!=3.11.0)", "keras", "numpy (==1.23.5)", "onnx (>=1.12.0)", "openvino (>=2024.0.0)", "tensorflow (>=2.0.0)", "tensorflowjs (>=3.9.0)", "tensorstore (>=0.1.63)"]
+extra = ["albumentations (>=1.4.6)", "hub-sdk (>=0.0.8)", "ipython", "pycocotools (>=2.0.7)"]
 logging = ["comet", "dvclive (>=2.12.0)", "tensorboard (>=2.13.0)"]
 
+[[package]]
+name = "ultralytics-thop"
+version = "2.0.5"
+description = "Ultralytics THOP package for fast computation of PyTorch model FLOPs and parameters."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "ultralytics_thop-2.0.5-py3-none-any.whl", hash = "sha256:1ee1ad9df2b5bb672da4d1f21c31473a1dc8c0fc859200148951747c059972e5"},
+    {file = "ultralytics_thop-2.0.5.tar.gz", hash = "sha256:10f8b590fc992a79ab8ae780a71fbe997a01007aa12c2d55033a00dd8f19585c"},
+]
+
+[package.dependencies]
+numpy = "*"
+torch = "*"
+
 [[package]]
 name = "uritemplate"
 version = "4.1.1"
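With the lock file updated, thop is replaced by ultralytics-thop and ultralytics moves to 8.2.79 with a numpy<2 constraint. A quick way to confirm the resolved release after running poetry install (assumes the package is importable in the active environment):

import ultralytics

print(ultralytics.__version__)  # expected to print 8.2.79 with this lock file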