diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..069a6b1
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,400 @@
+
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+  d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the "Licensor." The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/NOTICE b/NOTICE
new file mode 100644
index 0000000..17f0a5c
--- /dev/null
+++ b/NOTICE
@@ -0,0 +1,30 @@
+ZIM
+Copyright (c) 2024-present NAVER Cloud Corp.
+
+Creative Commons Attribution-NonCommercial 4.0 International
+
+A summary of the CC BY-NC 4.0 license is located here:
+ https://creativecommons.org/licenses/by-nc/4.0/
+
+This project contains subcomponents with separate copyright notices and license terms.
+Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
+
+=====
+
+facebookresearch/segment-anything
+https://github.com/facebookresearch/segment-anything
+
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+=====
diff --git a/README.md b/README.md
new file mode 100755
index 0000000..6df2161
--- /dev/null
+++ b/README.md
@@ -0,0 +1,32 @@
+# ZIM
+
+**ZIM: Zero-Shot Image Matting for Anything**
+
+[Beomyoung Kim](https://beomyoung-kim.github.io/), Chanyong Shin, Joonhyun Jeong, Hyungsik Jung, Se-Yun Lee, Sewhan Chun, Dong-Hyun Hwang, Joonsang Yu
+
+NAVER Cloud, ImageVision
+
+[![Paper](https://img.shields.io/badge/Paper-arxiv)](https://arxiv.org)
+[![Page](https://img.shields.io/badge/Project_page-blue)](https://naver-ai.github.io/ZIM)
+[![Demo](https://img.shields.io/badge/Demo-yellow)](https://huggingface.co/spaces/naver-iv/ZIM_Zero-Shot-Image-Matting)
+[![Data](https://img.shields.io/badge/Data-gray)](https://huggingface.co/datasets/naver-iv/MicroMat-3K)
+
+
+
+## Introduction
+
+In this paper, we introduce a novel zero-shot image matting model. Recent models like SAM (Segment Anything Model) exhibit strong zero-shot capabilities, but they fall short in generating fine-grained, high-precision masks. To address this limitation, we propose two key contributions: First, we develop a label converter that transforms segmentation labels into detailed matte labels, creating the new SA1B-Matte dataset. This enables the model to generate high-quality, micro-level matte masks without costly manual annotations. Second, we design a zero-shot matting model equipped with a hierarchical pixel decoder and prompt-aware masked attention mechanism, improving both the resolution of mask outputs and the model's ability to focus on specific regions based on user prompts. We evaluate our model using the newly introduced ZIM test set, which contains high-quality micro-level matte labels. Experimental results show that our model outperforms SAM and other existing methods in precision and zero-shot generalization. Furthermore, we demonstrate the versatility of our approach in downstream tasks, including image inpainting and 3D neural radiance fields (NeRF), where the ability to produce precise matte masks is crucial. Our contributions provide a robust foundation for advancing zero-shot image matting and its applications across a wide range of computer vision tasks.
+
+
+## Updates
+**Available Soon**
+
+
+## Installation
+
+Our implementation is based on [SAM](https://github.com/facebookresearch/segment-anything).
+
+Please check the [installation instructions](INSTALL.md).
+
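+A minimal usage sketch, adapted from the demo scripts in this repository (`demo/gradio_demo.py`); the checkpoint path and the single center-point prompt below are illustrative placeholders:
+
+```python
+import cv2
+import numpy as np
+import torch
+from zim import zim_model_registry, ZimPredictor
+
+backbone = "vit_b"
+checkpoint = "results/zim_vit_b_2043"  # placeholder path, as used in the demo scripts
+
+model = zim_model_registry[backbone](checkpoint=checkpoint)
+if torch.cuda.is_available():
+    model.cuda()
+predictor = ZimPredictor(model)
+
+# Set the RGB image once; prompts can then be issued repeatedly against it.
+image = cv2.cvtColor(cv2.imread("demo/examples/examples_example1.jpg"), cv2.COLOR_BGR2RGB)
+predictor.set_image(image)
+
+# A single positive point at the image center (illustrative prompt only).
+h, w = image.shape[:2]
+masks, _, _ = predictor.predict(
+    point_coords=np.array([[w // 2, h // 2]]),
+    point_labels=np.array([1]),
+    multimask_output=False,
+)
+matte = np.uint8(masks[0] * 255)  # the demos rescale the matte to 0-255 for display
+cv2.imwrite("matte.png", matte)
+```
+
+The same model can also be wrapped in `ZimAutomaticMaskGenerator` for prompt-free mask generation, as done in the demo scripts.
+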
+## License
+
+Available Soon
\ No newline at end of file
diff --git a/config/__init__.py b/config/__init__.py
new file mode 100755
index 0000000..3396bb7
--- /dev/null
+++ b/config/__init__.py
@@ -0,0 +1 @@
+from config.config import generate_config
diff --git a/config/config.py b/config/config.py
new file mode 100755
index 0000000..6510baf
--- /dev/null
+++ b/config/config.py
@@ -0,0 +1,66 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from easydict import EasyDict as edict
+
+config_ = edict()
+
+"""
+ Common configs
+"""
+config_.data_root = "/mnt/tmp"
+config_.use_ddp = True
+config_.use_amp = False
+config_.local_rank = 0
+config_.world_size = 1
+config_.random_seed = 3407
+"""
+ Network configs
+"""
+config_.network = edict()
+config_.network.encoder = "vit_b"
+config_.network.decoder = "zim"
+config_.network.encode_kernel = 21
+"""
+ Evaluation configs
+"""
+config_.eval = edict()
+config_.eval.workers = 4
+config_.eval.image_size = 1024
+config_.eval.prompt_type = "point,bbox"
+config_.eval.model_list = "zim,sam"
+config_.eval.zim_weights = ""
+config_.eval.sam_weights = ""
+"""
+ Dataset configs
+"""
+config_.dataset = edict()
+config_.dataset.valset = "MicroMat3K"
+config_.dataset.data_type = "fine,coarse"
+config_.dataset.data_list_txt = "data_list.txt"
+
+
+def remove_prefix(text, prefix):
+ if text.startswith(prefix):
+ return text[len(prefix) :]
+ return text
+
+
+def generate_config(args):
+ # merge args & config
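+    # e.g., a (hypothetical) args dict like {"network_encoder": "vit_l", "eval_workers": 8, "amp": True}
+    # is folded into config_.network.encoder, config_.eval.workers, and config_.use_amp below.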
+ for k, v in args.items():
+ if k.startswith("network_"):
+ config_["network"][remove_prefix(k, "network_")] = v
+ elif k.startswith("eval_"):
+ config_["eval"][remove_prefix(k, "eval_")] = v
+ elif k.startswith("dataset_"):
+ config_["dataset"][remove_prefix(k, "dataset_")] = v
+ elif k == "amp":
+ config_["use_amp"] = v
+ else:
+ config_[k] = v
+ return config_
diff --git a/demo/examples/examples_example1.jpg b/demo/examples/examples_example1.jpg
new file mode 100755
index 0000000..940be02
Binary files /dev/null and b/demo/examples/examples_example1.jpg differ
diff --git a/demo/examples/examples_example2.jpg b/demo/examples/examples_example2.jpg
new file mode 100755
index 0000000..6bddf7a
Binary files /dev/null and b/demo/examples/examples_example2.jpg differ
diff --git a/demo/examples/examples_example3.jpg b/demo/examples/examples_example3.jpg
new file mode 100755
index 0000000..2a9a623
Binary files /dev/null and b/demo/examples/examples_example3.jpg differ
diff --git a/demo/examples/examples_example4.jpg b/demo/examples/examples_example4.jpg
new file mode 100755
index 0000000..3ff6740
Binary files /dev/null and b/demo/examples/examples_example4.jpg differ
diff --git a/demo/examples/examples_example5.jpg b/demo/examples/examples_example5.jpg
new file mode 100755
index 0000000..7fb28c9
Binary files /dev/null and b/demo/examples/examples_example5.jpg differ
diff --git a/demo/examples/examples_example6.jpg b/demo/examples/examples_example6.jpg
new file mode 100755
index 0000000..6a84b26
Binary files /dev/null and b/demo/examples/examples_example6.jpg differ
diff --git a/demo/examples/examples_example7.jpg b/demo/examples/examples_example7.jpg
new file mode 100755
index 0000000..d346c03
Binary files /dev/null and b/demo/examples/examples_example7.jpg differ
diff --git a/demo/examples/examples_example8.jpg b/demo/examples/examples_example8.jpg
new file mode 100755
index 0000000..673ffcd
Binary files /dev/null and b/demo/examples/examples_example8.jpg differ
diff --git a/demo/gradio_demo.py b/demo/gradio_demo.py
new file mode 100755
index 0000000..13dd444
--- /dev/null
+++ b/demo/gradio_demo.py
@@ -0,0 +1,354 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+import os, sys
+sys.path.append(os.getcwd())
+
+import os
+import torch
+import gradio as gr
+from gradio_image_prompter import ImagePrompter
+import numpy as np
+import cv2
+from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
+from zim.utils import show_mat_anns
+
+def get_shortest_axis(image):
+ h, w, _ = image.shape
+ return h if h < w else w
+
+def reset_image(image, prompts):
+ if image is None:
+ image = np.zeros((1024, 1024, 3), dtype=np.uint8)
+ else:
+ image = image['image']
+ predictor.set_image(image)
+ prompts = dict()
+ black = np.zeros(image.shape[:2], dtype=np.uint8)
+
+ return (image, image, image, black, black, prompts)
+
+def reset_example_image(image, prompts):
+ if image is None:
+ image = np.zeros((1024, 1024, 3), dtype=np.uint8)
+
+ predictor.set_image(image)
+ prompts = dict()
+ black = np.zeros(image.shape[:2], dtype=np.uint8)
+
+ image_dict = {}
+ image_dict['image'] = image
+ image_dict['prompts'] = prompts
+
+ return (image, image_dict, image, image, black, black, prompts)
+
+def run_amg(image):
+ masks = mask_generator.generate(image)
+ masks_vis = show_mat_anns(image, masks)
+
+ return masks_vis
+
+def run_model(image, prompts):
+ if not prompts:
+        raise gr.Error("Please provide at least one point or box prompt.")
+ point_coords = None
+ point_labels = None
+ boxes = None
+ zim_mask = None
+
+ if "point" in prompts:
+ point_coords, point_labels = [], []
+
+ for type, pts in prompts["point"]:
+ point_coords.append(pts)
+ point_labels.append(type)
+ point_coords = np.array(point_coords)
+ point_labels = np.array(point_labels)
+
+ if "bbox" in prompts:
+ boxes = prompts['bbox']
+ boxes = np.array(boxes)
+
+ if "scribble" in prompts:
+ point_coords, point_labels = [], []
+
+ for pts in prompts["scribble"]:
+ point_coords.append(np.flip(pts))
+ point_labels.append(1)
+ if len(point_coords) == 0:
+            raise gr.Error("Please draw a scribble first.")
+ point_coords = np.array(point_coords)
+ point_labels = np.array(point_labels)
+
+ zim_mask, _, _ = predictor.predict(
+ point_coords=point_coords,
+ point_labels=point_labels,
+ box=boxes,
+ multimask_output=False,
+ )
+ zim_mask = np.squeeze(zim_mask, axis=0)
+ zim_mask = np.uint8(zim_mask * 255)
+
+ return zim_mask
+
+def reset_scribble(image, scribble, prompts):
+ # scribble = dict()
+ for k in prompts.keys():
+ prompts[k] = []
+
+ for k, v in scribble.items():
+ scribble[k] = None
+
+ zim_mask = np.zeros_like(image)
+
+ return scribble, zim_mask
+
+def update_scribble(image, scribble, prompts):
+ if "point" in prompts:
+ del prompts["point"]
+
+ if "bbox" in prompts:
+ del prompts["bbox"]
+
+ prompts = dict() # reset prompt
+ scribble_mask = scribble["layers"][0][..., -1] > 0
+
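+    # Uniformly subsample up to 24 pixels from the scribble mask and use them as positive point prompts.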
+ scribble_coords = np.argwhere(scribble_mask)
+ n_points = min(len(scribble_coords), 24)
+ indices = np.linspace(0, len(scribble_coords)-1, n_points, dtype=int)
+ scribble_sampled = scribble_coords[indices]
+
+ prompts["scribble"] = scribble_sampled
+
+ zim_mask = run_model(image, prompts)
+
+ return zim_mask, prompts
+
+
+def draw_point(img, pt, size, color):
+ # draw circle with white boundary region
+ cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 1.3), (255, 255, 255), -1)
+ cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 0.9), color, -1)
+
+
+def draw_images(image, mask, prompts):
+ if len(prompts) == 0 or mask.shape[1] == 1:
+        return image, image
+
+ minor = get_shortest_axis(image)
+ size = int(minor / 80)
+
+ image = np.float32(image)
+
+ def blending(image, mask):
+ mask = np.float32(mask) / 255
+ blended_image = np.zeros_like(image, dtype=np.float32)
+ blended_image[:, :, :] = [108, 0, 192]
+ blended_image = (image * 0.5) + (blended_image * 0.5)
+
+ img_with_mask = mask[:, :, None] * blended_image + (1 - mask[:, :, None]) * image
+ img_with_mask = np.uint8(img_with_mask)
+
+ return img_with_mask
+
+ img_with_mask = blending(image, mask)
+ img_with_point = img_with_mask.copy()
+
+ if "point" in prompts:
+ for type, pts in prompts["point"]:
+ if type == "Positive":
+ color = (0, 0, 255)
+ draw_point(img_with_point, pts, size, color)
+ elif type == "Negative":
+ color = (255, 0, 0)
+ draw_point(img_with_point, pts, size, color)
+
+ size = int(minor / 200)
+
+    return (
+        np.uint8(image),
+        img_with_mask,
+    )
+
+def get_point_or_box_prompts(img, prompts):
+ image, img_prompts = img['image'], img['points']
+ point_prompts = []
+ box_prompts = []
+ for prompt in img_prompts:
+ for p in range(len(prompt)):
+ prompt[p] = int(prompt[p])
+ if prompt[2] == 2 and prompt[5] == 3: # box prompt
+ box_prompts = [[prompt[0], prompt[1], prompt[3], prompt[4]], ]
+ elif prompt[2] == 1 and prompt[5] == 4: # Positive point prompt
+ point_prompts.append((1, (prompt[0], prompt[1])))
+ elif prompt[2] == 0 and prompt[5] == 4: # Negative point prompt
+ point_prompts.append((0, (prompt[0], prompt[1])))
+
+ if "scribble" in prompts:
+ del prompts["scribble"]
+
+ if len(point_prompts) > 0:
+ prompts['point'] = point_prompts
+ elif 'point' in prompts:
+ del prompts['point']
+
+ if len(box_prompts) > 0:
+ prompts['bbox'] = box_prompts
+ elif 'bbox' in prompts:
+ del prompts['bbox']
+
+ zim_mask = run_model(image, prompts)
+
+ return image, zim_mask, prompts
+
+def get_examples():
+ assets_dir = os.path.join(os.path.dirname(__file__), 'examples')
+ images = os.listdir(assets_dir)
+ return [os.path.join(assets_dir, img) for img in images]
+
+if __name__ == "__main__":
+
+ backbone = "vit_l"
+ ckpt_p = "results/zim_vit_l_2092"
+
+ model = zim_model_registry[backbone](checkpoint=ckpt_p)
+ if torch.cuda.is_available():
+ model.cuda()
+
+ predictor = ZimPredictor(model)
+ mask_generator = ZimAutomaticMaskGenerator(
+ model,
+ pred_iou_thresh=0.7,
+ points_per_batch=8,
+ stability_score_thresh=0.9,
+ )
+
+ with gr.Blocks() as demo:
+        gr.Markdown("# [Demo] ZIM: Zero-Shot Image Matting for Anything")
+
+ prompts = gr.State(dict())
+ img = gr.Image(visible=False)
+ example_image = gr.Image(visible=False)
+
+ with gr.Row():
+ with gr.Column():
+ # Point and Bbox prompt
+ with gr.Tab(label="Point or Box"):
+ img_with_point_or_box = ImagePrompter(
+ label="query image",
+ sources="upload"
+ )
+ interactions = "Left Click (Pos) | Middle/Right Click (Neg) | Press Move (Box)"
+                    gr.Markdown("[🖱️] 👉 {} 👈".format(interactions))
+ run_bttn = gr.Button("Run")
+ amg_bttn = gr.Button("Automatic Mask Generation")
+
+ # Scribble prompt
+ with gr.Tab(label="Scribble"):
+ img_with_scribble = gr.ImageEditor(
+ label="Scribble",
+ brush=gr.Brush(colors=["#00FF00"], default_size=15),
+ sources="upload",
+ transforms=None,
+ layers=False
+ )
+ interactions = "Press Move (Scribble)"
+                    gr.Markdown("Step 1. Select the Draw button")
+                    gr.Markdown("Step 2. 👉 {} 👈".format(interactions))
+ scribble_bttn = gr.Button("Run")
+ scribble_reset_bttn = gr.Button("Reset Scribbles")
+ amg_scribble_bttn = gr.Button("Automatic Mask Generation")
+
+ # Example image
+ gr.Examples(get_examples(), inputs=[example_image])
+
+ # with gr.Row():
+ with gr.Column():
+ with gr.Tab(label="ZIM Image"):
+ img_with_zim_mask = gr.Image(
+ label="ZIM Image",
+ interactive=False
+ )
+
+ with gr.Tab(label="ZIM Mask"):
+ zim_mask = gr.Image(
+ label="ZIM Mask",
+ image_mode="L",
+ interactive=False
+ )
+ with gr.Tab(label="ZIM Auto Mask"):
+ zim_amg = gr.Image(
+ label="ZIM Auto Mask",
+ interactive=False
+ )
+
+ example_image.change(
+ reset_example_image,
+ [example_image, prompts],
+ [
+ img,
+ img_with_point_or_box,
+ img_with_scribble,
+ img_with_zim_mask,
+ zim_amg,
+ zim_mask,
+ prompts,
+ ]
+ )
+
+ img_with_point_or_box.upload(
+ reset_image,
+ [img_with_point_or_box, prompts],
+ [
+ img,
+ img_with_scribble,
+ img_with_zim_mask,
+ zim_amg,
+ zim_mask,
+ prompts,
+ ],
+ )
+
+ amg_bttn.click(
+ run_amg,
+ [img],
+ [zim_amg]
+ )
+ amg_scribble_bttn.click(
+ run_amg,
+ [img],
+ [zim_amg]
+ )
+
+ run_bttn.click(
+ get_point_or_box_prompts,
+ [img_with_point_or_box, prompts],
+ [img, zim_mask, prompts]
+ )
+
+ zim_mask.change(
+ draw_images,
+ [img, zim_mask, prompts],
+ [
+ img, img_with_zim_mask,
+ ],
+ )
+ scribble_reset_bttn.click(
+ reset_scribble,
+ [img, img_with_scribble, prompts],
+ [img_with_scribble, zim_mask],
+ )
+ scribble_bttn.click(
+ update_scribble,
+ [img, img_with_scribble, prompts],
+ [zim_mask, prompts],
+ )
+
+ demo.queue()
+ demo.launch(
+ server_name="0.0.0.0",
+ server_port=11928,
+ )
\ No newline at end of file
diff --git a/demo/gradio_demo_comparison.py b/demo/gradio_demo_comparison.py
new file mode 100755
index 0000000..5986262
--- /dev/null
+++ b/demo/gradio_demo_comparison.py
@@ -0,0 +1,418 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+import os, sys
+sys.path.append(os.getcwd())
+
+# Gradio demo, comparison SAM vs ZIM
+import os
+import torch
+import gradio as gr
+from gradio_image_prompter import ImagePrompter
+import numpy as np
+import cv2
+from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+from zim.utils import show_mat_anns
+
+def get_shortest_axis(image):
+ h, w, _ = image.shape
+ return h if h < w else w
+
+def reset_image(image, prompts):
+ if image is None:
+ image = np.zeros((1024, 1024, 3), dtype=np.uint8)
+ else:
+ image = image['image']
+ zim_predictor.set_image(image)
+ sam_predictor.set_image(image)
+ prompts = dict()
+ black = np.zeros(image.shape[:2], dtype=np.uint8)
+
+ return (image, image, image, image, black, black, black, black, prompts)
+
+def reset_example_image(image, prompts):
+ if image is None:
+ image = np.zeros((1024, 1024, 3), dtype=np.uint8)
+
+ zim_predictor.set_image(image)
+ sam_predictor.set_image(image)
+ prompts = dict()
+ black = np.zeros(image.shape[:2], dtype=np.uint8)
+
+ image_dict = {}
+ image_dict['image'] = image
+ image_dict['prompts'] = prompts
+
+ return (image, image_dict, image, image, image, black, black, black, black, prompts)
+
+def run_amg(image):
+ zim_masks = zim_mask_generator.generate(image)
+ zim_masks_vis = show_mat_anns(image, zim_masks)
+
+ sam_masks = sam_mask_generator.generate(image)
+ sam_masks_vis = show_mat_anns(image, sam_masks)
+
+ return zim_masks_vis, sam_masks_vis
+
+
+def run_model(image, prompts):
+ if not prompts:
+        raise gr.Error("Please provide at least one point or box prompt.")
+ point_coords = None
+ point_labels = None
+ boxes = None
+
+ if "point" in prompts:
+ point_coords, point_labels = [], []
+
+ for type, pts in prompts["point"]:
+ point_coords.append(pts)
+ point_labels.append(type)
+ point_coords = np.array(point_coords)
+ point_labels = np.array(point_labels)
+
+ if "bbox" in prompts:
+ boxes = prompts['bbox']
+ boxes = np.array(boxes)
+
+ if "scribble" in prompts:
+ point_coords, point_labels = [], []
+
+ for pts in prompts["scribble"]:
+ point_coords.append(np.flip(pts))
+ point_labels.append(1)
+ if len(point_coords) == 0:
+            raise gr.Error("Please draw a scribble first.")
+ point_coords = np.array(point_coords)
+ point_labels = np.array(point_labels)
+
+ # run ZIM
+ zim_mask, _, _ = zim_predictor.predict(
+ point_coords=point_coords,
+ point_labels=point_labels,
+ box=boxes,
+ multimask_output=False,
+ )
+ zim_mask = np.squeeze(zim_mask, axis=0)
+ zim_mask = np.uint8(zim_mask * 255)
+
+ # run SAM
+ sam_mask, _, _ = sam_predictor.predict(
+ point_coords=point_coords,
+ point_labels=point_labels,
+ box=boxes,
+ multimask_output=False,
+ )
+ sam_mask = np.squeeze(sam_mask, axis=0)
+ sam_mask = np.uint8(sam_mask * 255)
+
+ return zim_mask, sam_mask
+
+def reset_scribble(image, scribble, prompts):
+ # scribble = dict()
+ for k in prompts.keys():
+ prompts[k] = []
+
+ for k, v in scribble.items():
+ scribble[k] = None
+
+    black = np.zeros(image.shape[:2], dtype=np.uint8)
+
+ return scribble, black, black
+
+def update_scribble(image, scribble, prompts):
+ if "point" in prompts:
+ del prompts["point"]
+
+ if "bbox" in prompts:
+ del prompts["bbox"]
+
+ prompts = dict() # reset prompt
+ scribble_mask = scribble["layers"][0][..., -1] > 0
+
+ scribble_coords = np.argwhere(scribble_mask)
+ n_points = min(len(scribble_coords), 24)
+ indices = np.linspace(0, len(scribble_coords)-1, n_points, dtype=int)
+ scribble_sampled = scribble_coords[indices]
+
+ prompts["scribble"] = scribble_sampled
+
+ zim_mask, sam_mask = run_model(image, prompts)
+
+ return zim_mask, sam_mask, prompts
+
+
+def draw_point(img, pt, size, color):
+ # draw circle with white boundary region
+ cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 1.3), (255, 255, 255), -1)
+ cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 0.9), color, -1)
+
+
+def draw_images(image, mask, prompts):
+ if len(prompts) == 0 or mask.shape[1] == 1:
+        return image, image
+
+ minor = get_shortest_axis(image)
+ size = int(minor / 80)
+
+ image = np.float32(image)
+
+ def blending(image, mask):
+ mask = np.float32(mask) / 255
+ blended_image = np.zeros_like(image, dtype=np.float32)
+ blended_image[:, :, :] = [108, 0, 192]
+ blended_image = (image * 0.5) + (blended_image * 0.5)
+
+ img_with_mask = mask[:, :, None] * blended_image + (1 - mask[:, :, None]) * image
+ img_with_mask = np.uint8(img_with_mask)
+
+ return img_with_mask
+
+ img_with_mask = blending(image, mask)
+ img_with_point = img_with_mask.copy()
+
+ if "point" in prompts:
+ for type, pts in prompts["point"]:
+ if type == "Positive":
+ color = (0, 0, 255)
+ draw_point(img_with_point, pts, size, color)
+ elif type == "Negative":
+ color = (255, 0, 0)
+ draw_point(img_with_point, pts, size, color)
+
+ size = int(minor / 200)
+
+    return (
+        np.uint8(image),
+        img_with_mask,
+    )
+
+def get_point_or_box_prompts(img, prompts):
+ image, img_prompts = img['image'], img['points']
+ point_prompts = []
+ box_prompts = []
+ for prompt in img_prompts:
+ for p in range(len(prompt)):
+ prompt[p] = int(prompt[p])
+ if prompt[2] == 2 and prompt[5] == 3: # box prompt
+ if len(box_prompts) != 0:
+ raise gr.Error("Please input only one BBox.", duration=5)
+ box_prompts.append([prompt[0], prompt[1], prompt[3], prompt[4]])
+ elif prompt[2] == 1 and prompt[5] == 4: # Positive point prompt
+ point_prompts.append((1, (prompt[0], prompt[1])))
+ elif prompt[2] == 0 and prompt[5] == 4: # Negative point prompt
+ point_prompts.append((0, (prompt[0], prompt[1])))
+
+ if "scribble" in prompts:
+ del prompts["scribble"]
+
+ if len(point_prompts) > 0:
+ prompts['point'] = point_prompts
+ elif 'point' in prompts:
+ del prompts['point']
+
+ if len(box_prompts) > 0:
+ prompts['bbox'] = box_prompts
+ elif 'bbox' in prompts:
+ del prompts['bbox']
+
+ zim_mask, sam_mask = run_model(image, prompts)
+
+ return image, zim_mask, sam_mask, prompts
+
+def get_examples():
+ assets_dir = os.path.join(os.path.dirname(__file__), 'examples')
+ images = os.listdir(assets_dir)
+ return [os.path.join(assets_dir, img) for img in images]
+
+if __name__ == "__main__":
+ backbone = "vit_b"
+
+ # load ZIM
+ ckpt_mat = "results/zim_vit_b_2043"
+ zim = zim_model_registry[backbone](checkpoint=ckpt_mat)
+ if torch.cuda.is_available():
+ zim.cuda()
+ zim_predictor = ZimPredictor(zim)
+ zim_mask_generator = ZimAutomaticMaskGenerator(
+ zim,
+ pred_iou_thresh=0.7,
+ points_per_batch=8,
+ stability_score_thresh=0.9,
+ )
+
+ # load SAM
+ ckpt_sam = "results/sam_vit_b_01ec64.pth"
+ sam = sam_model_registry[backbone](checkpoint=ckpt_sam)
+ if torch.cuda.is_available():
+ sam.cuda()
+ sam_predictor = SamPredictor(sam)
+ sam_mask_generator = SamAutomaticMaskGenerator(
+ sam,
+ points_per_batch=8,
+ )
+
+ with gr.Blocks() as demo:
+ gr.Markdown("# [Demo] ZIM: Zero-Shot Image Matting for Anything")
+
+ prompts = gr.State(dict())
+ img = gr.Image(visible=False)
+ example_image = gr.Image(visible=False)
+
+ with gr.Row():
+ with gr.Column():
+ # Point and Bbox prompt
+ with gr.Tab(label="Point or Box"):
+ img_with_point_or_box = ImagePrompter(
+ label="query image",
+ sources="upload"
+ )
+ interactions = "Left Click (Pos) | Middle/Right Click (Neg) | Press Move (Box)"
+                    gr.Markdown("[🖱️] 👉 {} 👈".format(interactions))
+ run_bttn = gr.Button("Run")
+ amg_bttn = gr.Button("Automatic Mask Generation")
+
+ # Scribble prompt
+ with gr.Tab(label="Scribble"):
+ img_with_scribble = gr.ImageEditor(
+ label="Scribble",
+ brush=gr.Brush(colors=["#00FF00"], default_size=40),
+ sources="upload",
+ transforms=None,
+ layers=False
+ )
+ interactions = "Press Move (Scribble)"
+                    gr.Markdown("Step 1. Select the Draw button")
+                    gr.Markdown("Step 2. 👉 {} 👈".format(interactions))
+ scribble_bttn = gr.Button("Run")
+ scribble_reset_bttn = gr.Button("Reset Scribbles")
+ amg_scribble_bttn = gr.Button("Automatic Mask Generation")
+
+ # Example image
+ gr.Examples(get_examples(), inputs=[example_image])
+
+ # with gr.Row():
+ with gr.Column():
+ with gr.Tab(label="ZIM Image"):
+ img_with_zim_mask = gr.Image(
+ label="ZIM Image",
+ interactive=False
+ )
+
+ with gr.Tab(label="ZIM Mask"):
+ zim_mask = gr.Image(
+ label="ZIM Mask",
+ image_mode="L",
+ interactive=False
+ )
+ with gr.Tab(label="ZIM Auto Mask"):
+ zim_amg = gr.Image(
+ label="ZIM Auto Mask",
+ interactive=False
+ )
+
+ with gr.Column():
+ with gr.Tab(label="SAM Image"):
+ img_with_sam_mask = gr.Image(
+ label="SAM image",
+ interactive=False
+ )
+
+ with gr.Tab(label="SAM Mask"):
+ sam_mask = gr.Image(
+ label="SAM Mask",
+ image_mode="L",
+ interactive=False
+ )
+
+ with gr.Tab(label="SAM Auto Mask"):
+ sam_amg = gr.Image(
+ label="SAM Auto Mask",
+ interactive=False
+ )
+
+ example_image.change(
+ reset_example_image,
+ [example_image, prompts],
+ [
+ img,
+ img_with_point_or_box,
+ img_with_scribble,
+ img_with_zim_mask,
+ img_with_sam_mask,
+ zim_amg,
+ sam_amg,
+ zim_mask,
+ sam_mask,
+ prompts,
+ ]
+ )
+
+ img_with_point_or_box.upload(
+ reset_image,
+ [img_with_point_or_box, prompts],
+ [
+ img,
+ img_with_scribble,
+ img_with_zim_mask,
+ img_with_sam_mask,
+ zim_amg,
+ sam_amg,
+ zim_mask,
+ sam_mask,
+ prompts,
+ ],
+ )
+
+ amg_bttn.click(
+ run_amg,
+ [img],
+ [zim_amg, sam_amg]
+ )
+ amg_scribble_bttn.click(
+ run_amg,
+ [img],
+ [zim_amg, sam_amg]
+ )
+
+ run_bttn.click(
+ get_point_or_box_prompts,
+ [img_with_point_or_box, prompts],
+ [img, zim_mask, sam_mask, prompts]
+ )
+
+ zim_mask.change(
+ draw_images,
+ [img, zim_mask, prompts],
+ [
+ img, img_with_zim_mask,
+ ],
+ )
+ sam_mask.change(
+ draw_images,
+ [img, sam_mask, prompts],
+ [
+ img, img_with_sam_mask,
+ ],
+ )
+
+ scribble_reset_bttn.click(
+ reset_scribble,
+ [img, img_with_scribble, prompts],
+ [img_with_scribble, zim_mask, sam_mask],
+ )
+ scribble_bttn.click(
+ update_scribble,
+ [img, img_with_scribble, prompts],
+ [zim_mask, sam_mask, prompts],
+ )
+
+ demo.queue()
+ demo.launch(
+ server_name="0.0.0.0",
+ server_port=11928,
+ )
\ No newline at end of file
diff --git a/eval/eval_loader.py b/eval/eval_loader.py
new file mode 100755
index 0000000..78fb760
--- /dev/null
+++ b/eval/eval_loader.py
@@ -0,0 +1,117 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import os
+import numpy as np
+import torch
+import json
+from PIL import Image
+import torch.utils.data as data
+from torch.utils.data.distributed import DistributedSampler
+
+def get_evalloader(config):
+ loader_dict = {}
+
+ for data_type in config.dataset.data_type:
+ dataset = Dataset(
+ config.data_root,
+ config.dataset,
+ data_type,
+ )
+
+ if config.local_rank == 0:
+ print(f"LOG) ZIM Dataset: {data_type} ({len(dataset)})")
+
+ sampler = None
+
+ if config.use_ddp:
+ sampler = DistributedSampler(
+ dataset,
+ rank=config.local_rank,
+ num_replicas=config.world_size,
+ )
+
+ dataloader = data.DataLoader(
+ dataset,
+ batch_size=1,
+ num_workers=config.eval.workers,
+ sampler=sampler,
+ shuffle=False,
+ pin_memory=True,
+ drop_last=False,
+ )
+ loader_dict[data_type] = dataloader
+
+ return loader_dict
+
+
+class Dataset(data.Dataset):
+ def __init__(
+ self,
+ data_root,
+ dataset_config,
+ data_type,
+ ):
+ super(Dataset, self).__init__()
+ self.root = os.path.join(data_root, dataset_config.valset)
+
+ with open(os.path.join(self.root, dataset_config.data_list_txt), "r") as f:
+ f_list = f.read().splitlines()
+ f_list = [p for p in f_list if data_type in p]
+
+ self.images = []
+ self.mattes = []
+ self.jsons = []
+
+ for fname in f_list:
+ img_path, matte_path, json_path, seg_path = fname.split(" ")
+
+ img_path = os.path.join(self.root, img_path)
+ matte_path = os.path.join(self.root, matte_path)
+ json_path = os.path.join(self.root, json_path)
+
+ self.images.append(img_path)
+ self.mattes.append(matte_path)
+ self.jsons.append(json_path)
+
+ assert len(self.images) == len(self.mattes) == len(self.jsons)
+
+
+ def __getitem__(self, index):
+ fname = os.path.basename(self.mattes[index])
+
+ img = Image.open(self.images[index]).convert('RGB')
+ matte = Image.open(self.mattes[index]).convert('L')
+ orig_w, orig_h = img.size
+
+ img = np.float32(img)
+ matte = np.float32(matte) / 255.
+
+ ratio = (matte > 0.3).sum() / matte.size
+
+ with open(self.jsons[index], "r") as f:
+ meta_data = json.load(f)
+
+ points = meta_data["point"]
+ points += [(-1, -1, -1) for _ in range(50-len(points))] # padding
+
+ bbox = meta_data["bbox"]
+
+ output = {
+ "images": torch.tensor(img, dtype=torch.float),
+ "mattes": torch.tensor(matte, dtype=torch.float),
+ "points": torch.tensor(points, dtype=torch.float),
+ "bboxes": torch.tensor(bbox, dtype=torch.float),
+ "fname": fname,
+ "ratio": ratio,
+ }
+
+ return output
+
+ def __len__(self):
+ return len(self.images)
+
diff --git a/eval/evaluator.py b/eval/evaluator.py
new file mode 100755
index 0000000..9aae38c
--- /dev/null
+++ b/eval/evaluator.py
@@ -0,0 +1,146 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+from torch import nn
+from typing import Any, Dict, List, Tuple
+import numpy as np
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from zim.build_model import build_zim_model
+from zim.predictor import ZimPredictor
+from segment_anything import SamPredictor, sam_model_registry
+
+def load_sam_evaluator(config, device):
+ sam = sam_model_registry[config.network.encoder](checkpoint=config.eval.sam_weights).cuda(device)
+ sam_evaluator = SamEvaluator(sam, config.eval.prompt_type)
+ if config.use_ddp:
+ sam_evaluator = DDP(
+ sam_evaluator,
+ device_ids=[config.local_rank],
+ output_device=config.local_rank,
+ )
+ sam_evaluator.eval()
+ return sam_evaluator
+
+def load_zim_evaluator(config, device):
+ zim = build_zim_model(config.eval.zim_weights).cuda(device)
+ zim_evaluator = ZimEvaluator(zim, config.eval.prompt_type)
+
+ return zim_evaluator
+
+class SamEvaluator(SamPredictor, nn.Module):
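+    """SAM predictor wrapped as an nn.Module so it can run under DistributedDataParallel during evaluation."""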
+ def __init__(
+ self,
+ sam_model,
+ prompt_type: List[str] = None
+ ):
+ super().__init__(sam_model=sam_model)
+ self.prompt_type = prompt_type
+
+ def forward(self, batched_input, multimask_output: bool = False):
+ input_images = batched_input["images"]
+
+ outputs = {
+ prompt: {
+ "masks": [],
+ } for prompt in self.prompt_type
+ }
+
+ with torch.inference_mode():
+ for idx, input_image in enumerate(input_images):
+ input_image = input_image.cpu().numpy().astype(np.uint8)
+ self.set_image(image=input_image)
+
+ for prompt in self.prompt_type:
+ point_coords = None
+ point_labels = None
+ bbox = None
+
+ if prompt == "point":
+ points = batched_input["points"][idx]
+ points = points[points[:, 2] >= 0] # remove points whose label=-1
+ point_coords = points[:, :2].cpu().numpy()
+ point_labels = points[:, 2].cpu().numpy()
+
+ elif prompt == "bbox":
+ bbox = batched_input["bboxes"][idx]
+ bbox = bbox.unsqueeze(0).cpu().numpy()
+
+ masks, _, _ = self.predict(
+ point_coords=point_coords,
+ point_labels=point_labels,
+ box=bbox,
+ multimask_output=False,
+ )
+ masks = torch.from_numpy(masks).float().unsqueeze(0).to(self.device)
+
+ outputs[prompt]["masks"].append(masks)
+
+ # Concat through batch dimension
+ for prompt in self.prompt_type:
+ for k, v in outputs[prompt].items():
+ if len(v) > 0:
+ outputs[prompt][k] = torch.cat(v, dim=0)
+
+ return outputs
+
+
+class ZimEvaluator(ZimPredictor, nn.Module):
+ def __init__(
+ self,
+ model,
+ prompt_type: List[str] = None
+ ) -> None:
+ super().__init__(model=model)
+ self.prompt_type = prompt_type
+
+ def forward(self, batched_input, multimask_output: bool = False):
+ input_images = batched_input["images"]
+
+ outputs = {
+ prompt: {
+ "masks": [],
+ } for prompt in self.prompt_type
+ }
+
+ with torch.inference_mode():
+ for idx, input_image in enumerate(input_images):
+ input_image = input_image.cpu().numpy().astype(np.uint8)
+ self.set_image(image=input_image)
+
+ for prompt in self.prompt_type:
+ point_coords = None
+ point_labels = None
+ bbox = None
+
+ if prompt == "point":
+ points = batched_input["points"][idx]
+ points = points[points[:, 2] >= 0] # remove points whose label=-1
+ point_coords = points[:, :2].cpu().numpy()
+ point_labels = points[:, 2].cpu().numpy()
+
+ elif prompt == "bbox":
+ bbox = batched_input["bboxes"][idx]
+ bbox = bbox.cpu().numpy()
+
+ masks, _, _ = self.predict(
+ point_coords=point_coords,
+ point_labels=point_labels,
+ box=bbox,
+ multimask_output=False,
+ )
+ masks = torch.from_numpy(masks).float().unsqueeze(0).to(self.device)
+
+ outputs[prompt]["masks"].append(masks)
+
+ for prompt in self.prompt_type:
+ for k, v in outputs[prompt].items():
+ if len(v) > 0:
+ outputs[prompt][k] = torch.cat(v, dim=0)
+
+ return outputs
diff --git a/eval/main_eval.py b/eval/main_eval.py
new file mode 100755
index 0000000..805bf67
--- /dev/null
+++ b/eval/main_eval.py
@@ -0,0 +1,129 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+from zim.utils import AverageMeter, print_once
+from .metric import compute_eval_scores, get_gradfilter
+
+def run_eval(
+ config,
+ valloader,
+ evaluator_dict
+):
+ score_dict = {}
+
+ for model_name, evaluator in evaluator_dict.items():
+ print_once(f'\nLOG) {model_name} evaluation start.')
+ model_score_dict = {}
+ for name, loader in valloader.items():
+ print_once(f"LOG) evaluate {model_name} on {name} dataset")
+
+ score = evaluate(
+ name=name,
+ evaluator=evaluator,
+ dataloader=loader,
+ prompt_type=config.eval.prompt_type,
+ use_ddp=config.use_ddp,
+ enable_amp=config.use_amp,
+ model_name=model_name,
+ )
+ model_score_dict[name] = score
+ score_dict[model_name] = model_score_dict
+
+    print_once('\nLOG) All evaluation done. Results:')
+
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+ result = "\n============================================\n"
+
+ for k, v in score_dict.items():
+ print_once('\n')
+ for data_type, log in v.items():
+ for _k, _v in log.items():
+ for __k, __v in _v.items():
+ result += f'Model: {k}'
+ result += f', Prompt: {_k}'
+ result += f', Scale: {__k}'
+ result += f', dataset: {config.dataset.valset} ({data_type})\n'
+ result += f'{__v}\n\n'
+ result += "============================================\n"
+ print_once(result)
+
+
+def evaluate(
+ name,
+ evaluator,
+ dataloader,
+ prompt_type,
+ use_ddp,
+ enable_amp,
+ model_name,
+):
+
+ metric_list = ["l1", "l2", "grad", "conn", "sad"]
+
+ scale_list = ["all", "S", "M", "L"]
+ #scale_list = ["all", ]
+
+ average_metric = {
+ prompt: {
+ scale: {
+ metric_name: AverageMeter(use_ddp) for metric_name in metric_list}
+ for scale in scale_list}
+ for prompt in prompt_type
+ }
+
+ batch_size = dataloader.batch_size
+ device = evaluator.device
+ grad_filter = get_gradfilter(device)
+
+ for _iter, batched_input in enumerate(dataloader):
+ for k, v in batched_input.items():
+ if type(v) == torch.Tensor:
+ batched_input[k] = v.to(device)
+
+        with torch.cuda.amp.autocast(enabled=enable_amp), torch.no_grad():
+ batched_output = evaluator(batched_input)
+
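+        # 'ratio' is the foreground-area fraction of the GT matte (computed in eval_loader);
+        # it assigns each sample to the S (< 1%), M (< 10%), or L (>= 10%) scale bucket below.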
+ ratio = batched_input['ratio'][0]
+
+ for prompt in prompt_type:
+ logits = batched_output[prompt]["masks"]
+ mattes = batched_input["mattes"]
+
+ scores = compute_eval_scores(
+ logits, mattes, grad_filter,
+ )
+
+ for m in metric_list:
+ average_metric[prompt]["all"][m].update(scores[m], batch_size)
+
+ if "S" in average_metric[prompt] and ratio < 0.01:
+ average_metric[prompt]["S"][m].update(scores[m], batch_size)
+ elif "M" in average_metric[prompt] and ratio < 0.1:
+ average_metric[prompt]["M"][m].update(scores[m], batch_size)
+ elif "L" in average_metric[prompt] and ratio >= 0.1:
+ average_metric[prompt]["L"][m].update(scores[m], batch_size)
+
+ # gather the stats from all processes
+ result_dict = {}
+ for prompt in prompt_type:
+ result_dict[prompt] = {}
+
+ for scale in scale_list:
+ for k, v in average_metric[prompt][scale].items():
+ v.synch(device)
+
+ res = "Result"
+ res += f" | MSE {average_metric[prompt][scale]['l2'].avg:.4f}"
+ res += f" | SAD {average_metric[prompt][scale]['sad'].avg:.4f}"
+ res += f" | MAE {average_metric[prompt][scale]['l1'].avg:.4f}"
+ res += f" | Grad {average_metric[prompt][scale]['grad'].avg:.4f}"
+ res += f" | Conn {average_metric[prompt][scale]['conn'].avg:.4f}"
+
+ result_dict[prompt][scale] = res
+
+ return result_dict
diff --git a/eval/metric.py b/eval/metric.py
new file mode 100755
index 0000000..09c42f4
--- /dev/null
+++ b/eval/metric.py
@@ -0,0 +1,122 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+import torch.nn.functional as F
+from skimage.measure import label
+import numpy as np
+
+def compute_eval_scores(
+ preds, gts, grad_filter,
+):
+ """
+ preds : (B 1 H W)
+ gts : (B H W)
+ """
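+    # Note: L1, L2 (MSE), and Grad are mean errors scaled by 1e3; SAD and Conn are summed
+    # errors reported in units of 1e3.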
+ l1_dist_list = []
+ l2_dist_list = []
+ grad_list = []
+ conn_error_list = []
+ sad_error_list = []
+
+ for pred, gt in zip(preds, gts):
+ gt = gt.unsqueeze(0)
+
+ l1_dist = F.l1_loss(pred, gt) * 1e3
+ l2_dist = F.mse_loss(pred, gt) * 1e3
+ grad = compute_grad(pred, gt, grad_filter) * 1e3
+ sad_error = compute_sad_loss(pred, gt)
+ conn_error = compute_connectivity_error_torch(pred, gt)
+
+ l1_dist_list.append(l1_dist)
+ l2_dist_list.append(l2_dist)
+ grad_list.append(grad)
+ conn_error_list.append(conn_error)
+ sad_error_list.append(sad_error)
+
+ l1_dist = torch.stack(l1_dist_list, dim=0)
+ l2_dist = torch.stack(l2_dist_list, dim=0)
+ grad = torch.stack(grad_list, dim=0)
+ conn_error = torch.stack(conn_error_list, dim=0)
+ sad_error = torch.stack(sad_error_list, dim=0)
+
+ return {
+ "l1": l1_dist.mean().item(),
+ "l2": l2_dist.mean().item(),
+ "grad": grad.mean().item(),
+ "conn": conn_error.mean().item(),
+ "sad": sad_error.mean().item(),
+ }
+
+
+def compute_grad(preds, labels, grad_filter):
+
+ if preds.dim() == 3:
+ preds = preds.unsqueeze(1)
+
+ if labels.dim() == 3:
+ labels = labels.unsqueeze(1)
+
+ grad_preds = F.conv2d(preds, weight=grad_filter, padding=1)
+ grad_labels = F.conv2d(labels, weight=grad_filter, padding=1)
+ grad_preds = torch.sqrt((grad_preds * grad_preds).sum(dim=1, keepdim=True) + 1e-8)
+ grad_labels = torch.sqrt(
+ (grad_labels * grad_labels).sum(dim=1, keepdim=True) + 1e-8
+ )
+
+ return F.l1_loss(grad_preds, grad_labels)
+
+
+def compute_sad_loss(pred, target):
+ error_map = torch.abs((pred - target))
+ loss = torch.sum(error_map)
+
+ return loss / 1000.
+
+
+def getLargestCC(segmentation):
+ segmentation = segmentation.cpu().detach().numpy()
+ labels = label(segmentation, connectivity=1)
+ if labels.max() == 0:
+ return np.zeros_like(segmentation, dtype=bool)
+ largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1 # Ignore background label
+ return largestCC
+
+
+def compute_connectivity_error_torch(pred, target, step=0.1):
+ thresh_steps = list(torch.arange(0, 1 + step, step))
+ l_map = torch.ones_like(pred, dtype=torch.float) * -1
+ for i in range(1, len(thresh_steps)):
+ pred_alpha_thresh = (pred >= thresh_steps[i]).to(dtype=torch.int)
+ target_alpha_thresh = (target >= thresh_steps[i]).to(dtype=torch.int)
+
+ omega = torch.from_numpy(getLargestCC(pred_alpha_thresh * target_alpha_thresh)).to(pred.device, dtype=torch.int)
+ flag = ((l_map == -1) & (omega == 0)).to(dtype=torch.int)
+ l_map[flag == 1] = thresh_steps[i - 1]
+
+ l_map[l_map == -1] = 1
+
+ pred_d = pred - l_map
+ target_d = target - l_map
+ pred_phi = 1 - pred_d * (pred_d >= 0.15).to(dtype=torch.int)
+ target_phi = 1 - target_d * (target_d >= 0.15).to(dtype=torch.int)
+ loss = torch.sum(torch.abs(pred_phi - target_phi))
+
+ return loss / 1000.
+
+
+def get_gradfilter(device):
+ """
+ generate gradient filter as the conv kernel
+ """
+ grad_filter = []
+ grad_filter.append([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])
+ grad_filter.append([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
+ grad_filter = np.array(grad_filter)
+ grad_filter = np.expand_dims(grad_filter, axis=1)
+ grad_filter = grad_filter.astype(np.float32)
+ return torch.tensor(grad_filter).to(device)
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000..bac5016
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+torch
+torchvision
+easydict
+numpy
+opencv-python
+gradio==4.38.1
+gradio-image-prompter
+fastapi==0.112.2
+git+https://github.com/facebookresearch/segment-anything.git
+onnxruntime-gpu==1.17.0
\ No newline at end of file
diff --git a/script/amg.py b/script/amg.py
new file mode 100755
index 0000000..4e9fb36
--- /dev/null
+++ b/script/amg.py
@@ -0,0 +1,124 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+import os, sys
+sys.path.append(os.getcwd())
+
+import argparse
+import numpy as np
+import cv2
+import glob
+from tqdm import tqdm
+
+import torch
+from torch.multiprocessing import Process
+
+from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
+from zim.utils import show_mat_anns
+
+from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+
+def get_argparser():
+ parser = argparse.ArgumentParser()
+
+ # Path option
+ parser.add_argument("--img_dir", type=str)
+ parser.add_argument("--save_dir", type=str)
+ parser.add_argument("--model", type=str, default='zim,sam')
+ parser.add_argument("--device", type=str, default='cuda')
+ parser.add_argument("--workers", type=int, default=torch.cuda.device_count())
+
+ parser.add_argument("--backbone", type=str, default='vit_b')
+ parser.add_argument("--zim_ckpt", type=str, default=None)
+ parser.add_argument("--sam_ckpt", type=str, default=None)
+
+ parser.add_argument("--points_per_batch", type=int, default=16)
+ parser.add_argument("--pred_iou_thresh", type=float, default=0.6)
+ parser.add_argument("--stability_score_thresh", type=float, default=0.9)
+ parser.add_argument("--stability_score_offset", type=float, default=0.1)
+ parser.add_argument("--box_nms_thresh", type=float, default=0.7)
+ parser.add_argument("--crop_nms_thresh", type=float, default=0.7)
+ return parser
+
+
+def load_zim_amg(args):
+ zim = zim_model_registry[args.backbone](checkpoint=args.zim_ckpt).cuda()
+ mask_generator = ZimAutomaticMaskGenerator(
+ zim,
+ pred_iou_thresh=args.pred_iou_thresh,
+ points_per_batch=args.points_per_batch,
+ stability_score_thresh=args.stability_score_thresh,
+ stability_score_offset=args.stability_score_offset,
+ box_nms_thresh=args.box_nms_thresh,
+ crop_nms_thresh=args.crop_nms_thresh
+ )
+ return mask_generator
+
+def load_sam_amg(args):
+ sam = sam_model_registry[args.backbone](checkpoint=args.sam_ckpt).cuda()
+ mask_generator = SamAutomaticMaskGenerator(
+ sam,
+ pred_iou_thresh=args.pred_iou_thresh,
+ points_per_batch=args.points_per_batch,
+ stability_score_thresh=args.stability_score_thresh,
+ stability_score_offset=args.stability_score_offset,
+ box_nms_thresh=args.box_nms_thresh,
+ crop_nms_thresh=args.crop_nms_thresh
+ )
+ return mask_generator
+
+
+def run_amg(pid, args):
+ with torch.cuda.device(pid):
+
+ mask_generators = []
+ if "zim" in args.model:
+ mask_generators.append(load_zim_amg(args))
+
+ if "sam" in args.model:
+ mask_generators.append(load_sam_amg(args))
+
+ for n, img_path in enumerate(tqdm(img_list)):
+ if (n+1) % args.workers != pid:
+ continue
+
+ image = cv2.imread(img_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ with torch.cuda.amp.autocast(enabled=True):
+ masks = []
+ for mg in mask_generators:
+ m = mg.generate(image)
+ masks.append(m)
+
+ vis = [image]
+ for mask in masks:
+ vis.append(show_mat_anns(image, mask))
+
+ vis = cv2.hconcat(vis)
+
+ save_path = os.path.join(args.save_dir, os.path.basename(img_path))
+ cv2.imwrite(save_path, vis[:, :, ::-1])
+
+
+if __name__ == "__main__":
+
+ args = get_argparser().parse_args()
+ args.model = args.model.split(",")
+
+ img_list = glob.glob(f'{args.img_dir}/**', recursive=True)
+ img_list = [p for p in img_list if p.endswith((".jpg", ".png", ".jpeg"))]
+
+ os.makedirs(args.save_dir, exist_ok=True)
+
+ processes = []
+ for i in range(args.workers):
+ proc = Process(target=run_amg, args=(i, args))
+ processes.append(proc)
+ proc.start()
+ for proc in processes:
+ proc.join()
+
diff --git a/script/evaluation.py b/script/evaluation.py
new file mode 100755
index 0000000..16efd37
--- /dev/null
+++ b/script/evaluation.py
@@ -0,0 +1,90 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+import os, sys
+sys.path.append(os.getcwd())
+
+import warnings
+warnings.filterwarnings(action="ignore")
+
+import torch
+import numpy as np
+import random
+
+from zim.utils import get_parser, print_once
+from config.config import generate_config
+from eval.main_eval import run_eval
+from eval.evaluator import load_sam_evaluator, load_zim_evaluator
+from eval.eval_loader import get_evalloader
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.enabled = True
+
+def parse_args():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ return args
+
+def main(args):
+
+ config = generate_config(args)
+
+ # Setup random seed
+ torch.manual_seed(config.random_seed)
+ np.random.seed(config.random_seed)
+ random.seed(config.random_seed)
+
+ torch.cuda.set_device(config.local_rank)
+    device = torch.device("cuda")
+
+ n_gpus = torch.cuda.device_count()
+ if n_gpus <= 1:
+ config.use_ddp = False
+
+ # DDP init
+ if config.use_ddp:
+ torch.distributed.init_process_group(
+ backend="nccl", rank=config.local_rank, world_size=n_gpus
+ )
+ config.world_size = torch.distributed.get_world_size()
+ device = torch.device(f"cuda:{config.local_rank}")
+
+ print_once("LOG) Initialization start")
+
+ # Dataset list: string to list
+ if isinstance(config.dataset.data_type, str):
+ config.dataset.data_type = config.dataset.data_type.split(",")
+ if isinstance(config.eval.prompt_type, str):
+ config.eval.prompt_type = config.eval.prompt_type.split(",")
+
+ # Benchmarking model list: str to list
+ if isinstance(config.eval.model_list, str):
+ config.eval.model_list = config.eval.model_list.split(",")
+
+ val_loaders = get_evalloader(config)
+
+ # Define SAM
+ print_once("LOG) Start loading models")
+
+ evaluator_dict = {}
+ for model in config.eval.model_list:
+ if model == "sam":
+ evaluator_dict["sam"] = load_sam_evaluator(config, device)
+ elif model == "zim":
+ evaluator_dict["zim"] = load_zim_evaluator(config, device)
+
+ print_once(f"LOG) Loading model {list(evaluator_dict.keys())}")
+ print_once("LOG) Start evaluation")
+ run_eval(
+ config=config,
+ valloader=val_loaders,
+ evaluator_dict=evaluator_dict
+ )
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(vars(args))
diff --git a/script/run_amg.sh b/script/run_amg.sh
new file mode 100755
index 0000000..936fb7c
--- /dev/null
+++ b/script/run_amg.sh
@@ -0,0 +1,22 @@
+img_dir="demo/examples"
+save_dir="demo/amg"
+model="zim,sam"
+
+backbone="vit_b"
+zim_ckpt="results/zim_vit_b_2043.pt"
+sam_ckpt="results/sam_vit_b_01ec64.pth"
+
+points_per_batch=16
+pred_iou_thresh=0.7
+stability_score_thresh=0.9
+
+python script/amg.py \
+--img_dir ${img_dir} \
+--save_dir ${save_dir} \
+--model ${model} \
+--backbone ${backbone} \
+--zim_ckpt ${zim_ckpt} \
+--sam_ckpt ${sam_ckpt} \
+--points_per_batch ${points_per_batch} \
+--pred_iou_thresh ${pred_iou_thresh} \
+--stability_score_thresh ${stability_score_thresh} \
diff --git a/script/run_eval.sh b/script/run_eval.sh
new file mode 100755
index 0000000..7d67044
--- /dev/null
+++ b/script/run_eval.sh
@@ -0,0 +1,34 @@
+amp=True
+data_root=YOUR_DATA_ROOT
+
+# network
+encoder="vit_b"
+decoder="zim"
+
+# evaluation
+workers=4
+image_size=1024
+prompt_type="point,bbox"
+model_list="zim,sam"
+valset="MicroMat3K"
+data_type="fine,coarse"
+data_list_txt="data_list.txt"
+zim_weights="results/zim_vit_b_2043"
+sam_weights="results/sam_vit_b_01ec64.pth"
+
+
+ngpus=$(nvidia-smi --list-gpus | wc -l)
+torchrun --standalone --nnodes=1 --nproc_per_node=${ngpus} script/evaluation.py \
+--amp ${amp} \
+--data-root ${data_root} \
+--network-encoder ${encoder} \
+--network-decoder ${decoder} \
+--eval-workers ${workers} \
+--eval-image-size ${image_size} \
+--eval-prompt-type ${prompt_type} \
+--eval-model-list ${model_list} \
+--eval-zim-weights ${zim_weights} \
+--eval-sam-weights ${sam_weights} \
+--dataset-valset ${valset} \
+--dataset-data-type ${data_type} \
+--dataset-data-list-txt ${data_list_txt} \
diff --git a/zim/__init__.py b/zim/__init__.py
new file mode 100755
index 0000000..dcbfd00
--- /dev/null
+++ b/zim/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .build_model import build_zim_model, zim_model_registry
+from .predictor import ZimPredictor
+from .automatic_mask_generator import ZimAutomaticMaskGenerator
\ No newline at end of file
diff --git a/zim/automatic_mask_generator.py b/zim/automatic_mask_generator.py
new file mode 100755
index 0000000..6780b50
--- /dev/null
+++ b/zim/automatic_mask_generator.py
@@ -0,0 +1,378 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+This source code is based on code from the Segment Anything Model (SAM)
+(https://github.com/facebookresearch/segment-anything).
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from .modeling.zim import Zim
+from .predictor import ZimPredictor
+from .utils.amg import (
+ MaskData,
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class ZimAutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: Zim,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.88,
+ stability_score_thresh: float = 0.9,
+ stability_score_offset: float = 0.1,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ ) -> None:
+ """
+        Using a ZIM model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM with a ViT-H backbone.
+
+ Arguments:
+          model (Zim): The ZIM model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+ crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+        ), "Exactly one of points_per_side or point_grids must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+            raise ValueError("Can't have both points_per_side and point_grids be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+
+ if min_mask_region_area > 0:
+ import cv2 # type: ignore # noqa: F401
+
+ self.predictor = ZimPredictor(model)
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Filter small disconnected regions and holes in masks
+ if self.min_mask_region_area > 0:
+ mask_data = self.postprocess_small_regions(
+ mask_data,
+ self.min_mask_region_area,
+ max(self.box_nms_thresh, self.crop_nms_thresh),
+ )
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "logit": mask_data["logits"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
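+    # Hedged usage sketch (illustrative, not part of the original code); `image` is an
+    # HWC uint8 RGB array and the checkpoint path is a placeholder:
+    #
+    #   from zim import zim_model_registry, ZimAutomaticMaskGenerator
+    #
+    #   model = zim_model_registry["vit_b"](checkpoint="results/zim_vit_b_2043").cuda()
+    #   mask_generator = ZimAutomaticMaskGenerator(model, pred_iou_thresh=0.7)
+    #   anns = mask_generator.generate(image)
+    #   # each record carries: segmentation, logit, area, bbox, predicted_iou,
+    #   # point_coords, stability_score, crop_box
+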
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_image()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros_like(data["boxes"][:, 0]), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+ masks, iou_preds, _ = self.predictor.predict_torch(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=True,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ logits=(masks.flatten(0, 1) * 255).byte(),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+ )
+ del masks
+
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["logits"] = uncrop_masks(data["logits"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros_like(boxes[:, 0]), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
\ No newline at end of file
diff --git a/zim/build_model.py b/zim/build_model.py
new file mode 100755
index 0000000..76cb3e6
--- /dev/null
+++ b/zim/build_model.py
@@ -0,0 +1,29 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+This source code is based on code from the Segment Anything Model (SAM)
+(https://github.com/facebookresearch/segment-anything).
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+import os
+import torch
+
+from .modeling.zim import Zim
+from .modeling.encoder import ZIM_Encoder
+from .modeling.decoder import ZIM_Decoder
+
+def build_zim_model(checkpoint):
+
+ encoder = ZIM_Encoder(os.path.join(checkpoint, "encoder.onnx"))
+ decoder = ZIM_Decoder(os.path.join(checkpoint, "decoder.onnx"))
+ net = Zim(encoder, decoder)
+
+ return net
+
+zim_model_registry = {
+ "default": build_zim_model,
+ "vit_l": build_zim_model,
+ "vit_b": build_zim_model,
+}
+
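+# Hedged usage sketch (not part of the original code): `checkpoint` is expected to be a
+# directory holding the exported `encoder.onnx` and `decoder.onnx` files; the path below
+# is a placeholder taken from the evaluation script defaults.
+#
+#   from zim import zim_model_registry
+#
+#   model = zim_model_registry["vit_b"](checkpoint="results/zim_vit_b_2043")
+#   model = model.cuda()  # also switches both ONNX sessions to the CUDA execution provider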
diff --git a/zim/modeling/decoder.py b/zim/modeling/decoder.py
new file mode 100755
index 0000000..c9ded3a
--- /dev/null
+++ b/zim/modeling/decoder.py
@@ -0,0 +1,88 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+import torch
+from typing import Any, Callable
+import onnxruntime
+import numpy as np
+
+def np2tensor(np_array, device):
+ return torch.from_numpy(np_array).to(device)
+
+def tensor2np(torch_tensor):
+ if torch_tensor is None:
+ return None
+
+ return torch_tensor.detach().cpu().numpy()
+
+class ZIM_Decoder():
+ def __init__(self, onnx_path, num_threads=16):
+ self.onnx_path = onnx_path
+
+ sessionOptions = onnxruntime.SessionOptions()
+ sessionOptions.intra_op_num_threads = num_threads
+ sessionOptions.inter_op_num_threads = num_threads
+ providers = ["CPUExecutionProvider"]
+
+ self.ort_session = onnxruntime.InferenceSession(
+ onnx_path, sess_options=sessionOptions, providers=providers
+ )
+ self.num_mask_tokens = 4
+
+ def cuda(self, device_id=0):
+ providers = [
+ (
+ "CUDAExecutionProvider",
+ {
+ "device_id": device_id,
+ },
+ ),
+ ]
+
+ self.ort_session.set_providers(providers)
+
+ def forward(
+ self,
+ interm_feats,
+ image_embeddings,
+ points,
+ boxes,
+ attn_mask,
+ ):
+ device = image_embeddings.device
+
+ ort_inputs = {
+ "feat_D0": tensor2np(interm_feats[0]),
+ "feat_D1": tensor2np(interm_feats[1]),
+ "feat_D2": tensor2np(interm_feats[2]),
+ "image_embeddings": tensor2np(image_embeddings),
+ "attn_mask": tensor2np(attn_mask),
+ }
+
+ if points is not None:
+ point_coords, point_labels = points
+ ort_inputs["point_coords"] = tensor2np(point_coords.float())
+ ort_inputs["point_labels"] = tensor2np(point_labels.float())
+
+ # add paddings as done in SAM
+ padding_point = np.zeros((ort_inputs["point_coords"].shape[0], 1, 2), dtype=np.float32) - 0.5
+ padding_label = -np.ones((ort_inputs["point_labels"].shape[0], 1), dtype=np.float32)
+ ort_inputs["point_coords"] = np.concatenate([ort_inputs["point_coords"], padding_point], axis=1)
+ ort_inputs["point_labels"] = np.concatenate([ort_inputs["point_labels"], padding_label], axis=1)
+
+ if boxes is not None:
+ ort_inputs["point_coords"] = tensor2np(boxes.reshape(-1, 2, 2))
+ ort_inputs["point_labels"] = np.array([[2, 3]], dtype=np.float32).repeat(boxes.shape[0], 0)
+
+ masks, iou_predictions = self.ort_session.run(None, ort_inputs)
+
+ masks = np2tensor(masks, device)
+ iou_predictions = np2tensor(iou_predictions, device)
+
+ return masks, iou_predictions
+
+ __call__: Callable[..., Any] = forward
+
\ No newline at end of file
diff --git a/zim/modeling/encoder.py b/zim/modeling/encoder.py
new file mode 100755
index 0000000..730da07
--- /dev/null
+++ b/zim/modeling/encoder.py
@@ -0,0 +1,60 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+import torch
+from typing import Any, Callable
+import onnxruntime
+
+def np2tensor(np_array, device):
+ return torch.from_numpy(np_array).to(device)
+
+def tensor2np(torch_tensor):
+ return torch_tensor.detach().cpu().numpy()
+
+class ZIM_Encoder():
+ def __init__(self, onnx_path, num_threads=16):
+ self.onnx_path = onnx_path
+
+ sessionOptions = onnxruntime.SessionOptions()
+ sessionOptions.intra_op_num_threads = num_threads
+ sessionOptions.inter_op_num_threads = num_threads
+ providers = ["CPUExecutionProvider"]
+
+ self.ort_session = onnxruntime.InferenceSession(
+ onnx_path, sess_options=sessionOptions, providers=providers
+ )
+
+ def cuda(self, device_id=0):
+ providers = [
+ (
+ "CUDAExecutionProvider",
+ {
+ "device_id": device_id,
+ },
+ ),
+ ]
+
+ self.ort_session.set_providers(providers)
+
+ def forward(
+ self,
+ image,
+ ):
+ device = image.device
+
+ ort_inputs = {
+ "image": tensor2np(image),
+ }
+ image_embeddings, feat_D0, feat_D1, feat_D2 = self.ort_session.run(None, ort_inputs)
+
+ image_embeddings = np2tensor(image_embeddings, device)
+ feat_D0 = np2tensor(feat_D0, device)
+ feat_D1 = np2tensor(feat_D1, device)
+ feat_D2 = np2tensor(feat_D2, device)
+
+ return image_embeddings, (feat_D0, feat_D1, feat_D2)
+
+ __call__: Callable[..., Any] = forward
diff --git a/zim/modeling/zim.py b/zim/modeling/zim.py
new file mode 100755
index 0000000..c4cade5
--- /dev/null
+++ b/zim/modeling/zim.py
@@ -0,0 +1,190 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+This source code is based on code from the Segment Anything Model (SAM)
+(https://github.com/facebookresearch/segment-anything).
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from typing import Any, Dict, List
+
+def gaussian(sigma=6):
+ """
+ 2D Gaussian Kernel Generation.
+ """
+ size = 6 * sigma + 3
+ x = torch.arange(0, size, 1)
+ y = x[:, None]
+ x0, y0 = 3 * sigma + 1, 3 * sigma + 1
+ g = torch.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ return g
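+
+# Note (added for clarity, not in the original code): with the default sigma=6 the kernel
+# above is a 39x39 grid (size = 6*sigma + 3) whose peak value 1.0 sits at (19, 19);
+# Zim instantiates it with sigma = encode_kernel to build the soft point attention mask.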
+
+class Zim(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ *,
+ image_size: int = 1024,
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
+ ) -> None:
+ """
+        ZIM predicts object masks from an image and input prompts.
+
+ Arguments:
+ encoder : The backbone used to encode the
+ image into image embeddings that allow for efficient mask prediction.
+ decoder : Predicts masks from the image embeddings and given prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
+ """
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.output_activation = nn.Sigmoid()
+
+ self.image_size = image_size
+ self.register_buffer(
+ "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
+ )
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ self.mask_threshold: float = 0.5
+ self.image_format: str = "RGB"
+ self.num_mask_tokens = decoder.num_mask_tokens
+
+ self.encode_stride = 16
+ self.encode_kernel = 21
+ self.attn_mask_size = 64
+ self.g = gaussian(self.encode_kernel)
+
+ self.output_conv = nn.Conv2d(
+ self.num_mask_tokens,
+ self.num_mask_tokens,
+ kernel_size=1, stride=1, padding=0,
+ )
+
+ @property
+ def device(self) -> Any:
+ return self.pixel_mean.device
+
+ def cuda(self, device_id=None):
+        if isinstance(device_id, torch.device):
+ device_id = device_id.index
+
+ if device_id is None:
+ device_id = 0
+
+ device = torch.device(f"cuda:{device_id}")
+ super(Zim, self).cuda(device)
+
+ self.encoder.cuda(device_id)
+ self.decoder.cuda(device_id)
+
+ return self
+
+ def postprocess_masks(
+ self, masks: torch.Tensor, input_size: List[int], original_size: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Arguments:
+ masks (torch.Tensor): Batched masks from the decoder,
+ in BxCxHxW format.
+ input_size (tuple(int, int)): The size of the image input to the
+ model, in (H, W) format. Used to remove padding.
+ original_size (tuple(int, int)): The original size of the image
+ before resizing for input to the model, in (H, W) format.
+
+ Returns:
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+ is given by original_size.
+ """
+ masks = F.interpolate(
+ masks,
+ (self.image_size, self.image_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ masks = masks[..., : input_size[0], : input_size[1]]
+ masks = F.interpolate(
+ masks, original_size, mode="bilinear", align_corners=False
+ )
+ return masks
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.image_size - h
+ padw = self.image_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
+
+ def bbox_attn_mask(self, boxes):
+ """Prompt-aware Masked Attention: box prompt (binary attn mask) """
+ bs = boxes.shape[0]
+ attn_mask = torch.zeros((bs, self.attn_mask_size, self.attn_mask_size), device=boxes.device)
+
+ # attn_weight = attn_weight.masked_fill(m.logical_not(), -1e4)
+
+ for n in range(bs):
+ xmin, ymin, xmax, ymax = boxes[n]
+
+ xmin, xmax = min(xmin, xmax), max(xmin, xmax)
+ ymin, ymax = min(ymin, ymax), max(ymin, ymax)
+
+ xmin, xmax = int(xmin / self.encode_stride), int(xmax / self.encode_stride)
+ ymin, ymax = int(ymin / self.encode_stride), int(ymax / self.encode_stride)
+
+ xmin, ymin = max(0, xmin), max(0, ymin)
+ xmax = min(self.attn_mask_size, xmax+1)
+ ymax = min(self.attn_mask_size, ymax+1)
+
+ attn_mask[n, ymin:ymax, xmin:xmax] = 1
+
+ return attn_mask
+
+ def point_attn_mask(self, point_coords):
+ """Prompt-aware Masked Attention: point prompt (soft attn mask) """
+ bs = point_coords.shape[0]
+ attn_mask = torch.zeros((bs, self.attn_mask_size, self.attn_mask_size), device=point_coords.device)
+
+ if self.g.device != point_coords.device:
+ self.g = self.g.to(point_coords.device)
+
+ for n in range(bs):
+ for point in point_coords[n]:
+ x, y = int(point[0] / self.encode_stride), int(point[1].item() / self.encode_stride)
+
+ # outside image boundary
+ if x < 0 or y < 0 or x >= self.attn_mask_size or y >= self.attn_mask_size:
+ continue
+
+ # upper left
+ ul = int(round(x - 3 * self.encode_kernel - 1)), int(round(y - 3 * self.encode_kernel - 1))
+ # bottom right
+ br = int(round(x + 3 * self.encode_kernel + 2)), int(round(y + 3 * self.encode_kernel + 2))
+
+ c, d = int(max(0, -ul[0])), int(min(br[0], self.attn_mask_size) - ul[0])
+ a, b = int(max(0, -ul[1])), int(min(br[1], self.attn_mask_size) - ul[1])
+
+ cc, dd = int(max(0, ul[0])), int(min(br[0], self.attn_mask_size))
+ aa, bb = int(max(0, ul[1])), int(min(br[1], self.attn_mask_size))
+
+ attn_mask[n, aa:bb, cc:dd] = torch.maximum(
+ attn_mask[n, aa:bb, cc:dd], self.g[a:b, c:d]
+ )
+
+ return attn_mask
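+
+    # Hedged shape note (added for clarity): both attention-mask helpers above map prompts
+    # given in 1024x1024 input-image coordinates onto a (B, 64, 64) mask over the
+    # 16x-downsampled feature grid: a hard 0/1 window for boxes, a Gaussian bump per click.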
+
+
\ No newline at end of file
diff --git a/zim/predictor.py b/zim/predictor.py
new file mode 100755
index 0000000..09b0ab1
--- /dev/null
+++ b/zim/predictor.py
@@ -0,0 +1,275 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+This source code is based on code from the Segment Anything Model (SAM)
+(https://github.com/facebookresearch/segment-anything).
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
+from typing import Optional, Tuple, List
+
+from .utils import ResizeLongestSide
+
+class ZimPredictor:
+ def __init__(
+ self,
+ model,
+ ) -> None:
+ """
+        Uses ZIM to calculate the image embedding for an image, and then
+        allows repeated, efficient mask prediction given prompts.
+
+ Arguments:
+          model (Zim): The ZIM model to use for mask prediction.
+ """
+ super().__init__()
+ self.model = model.module if isinstance(model, DDP) else model
+ self.transform = ResizeLongestSide(self.model.image_size)
+ self.reset_image()
+
+ def set_image(
+ self,
+ image: np.ndarray,
+ image_format: str = "RGB",
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray): The image for calculating masks. Expects an
+ image in HWC uint8 format, with pixel values in [0, 255].
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
+ """
+ assert image_format in [
+ "RGB",
+ "BGR",
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ if image_format != self.model.image_format:
+ image = image[..., ::-1]
+
+ # Transform the image to the form expected by the model
+ input_image = self.transform.apply_image(image)
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+ self.set_torch_image(input_image_torch, image.shape[:2])
+
+ @torch.no_grad()
+ def set_torch_image(
+ self,
+ transformed_image: torch.Tensor,
+ original_image_size: Tuple[int, ...],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method. Expects the input
+ image to be already transformed to the format expected by the model.
+
+ Arguments:
+ transformed_image (torch.Tensor): The input image, with shape
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
+ original_image_size (tuple(int, int)): The size of the image
+ before transformation, in (H, W) format.
+ """
+ assert (
+ len(transformed_image.shape) == 4
+ and transformed_image.shape[1] == 3
+ and max(*transformed_image.shape[2:]) == self.model.image_size
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_size}."
+ self.reset_image()
+
+ self.original_size = original_image_size
+ self.input_size = tuple(transformed_image.shape[-2:])
+ input_image = self.model.preprocess(transformed_image)
+ self.features, self.interm_feats = self.model.encoder(input_image)
+ self.is_image_set = True
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+          box (np.ndarray or None): A length 4 array giving a box prompt to the
+            model, in XYXY format.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+          (np.ndarray): An array of shape CxHxW, where C is the number
+            of masks, containing the low resolution mask predictions.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ # Transform input prompts
+ coords_torch = None
+ labels_torch = None
+ box_torch = None
+
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.float, device=self.device)
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.transform.apply_boxes(box, self.original_size)
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks_np = masks[0].detach().cpu().numpy()
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
+
+ return masks_np, iou_predictions_np, low_res_masks_np
+
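+    # Hedged usage sketch (illustrative, not part of the original code); the checkpoint
+    # path and the click location are placeholders:
+    #
+    #   import numpy as np
+    #   from zim import zim_model_registry, ZimPredictor
+    #
+    #   model = zim_model_registry["vit_b"](checkpoint="results/zim_vit_b_2043").cuda()
+    #   predictor = ZimPredictor(model)
+    #   predictor.set_image(image)                    # HWC uint8 RGB array
+    #   mattes, ious, low_res = predictor.predict(
+    #       point_coords=np.array([[512, 512]]),
+    #       point_labels=np.array([1]),
+    #       multimask_output=True,
+    #   )                                             # mattes: (C, H, W) floats in [0, 1]
+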
+ @torch.no_grad()
+ def predict_torch(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using ResizeLongestSide.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+            model, in XYXY format.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks, containing the low resolution mask predictions.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ if point_coords is not None:
+ points = (point_coords, point_labels)
+ attn_mask = self.model.point_attn_mask(point_coords)
+ else:
+ points = None
+ attn_mask = self.model.bbox_attn_mask(boxes)
+
+ # Embed prompts
+ masks, iou_predictions = self.model.decoder(
+ interm_feats=self.interm_feats,
+ image_embeddings=self.features,
+ points=points,
+ boxes=boxes,
+ attn_mask=attn_mask,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(0, None)
+ else:
+ mask_slice = slice(0, 1)
+
+ masks = masks[:, mask_slice, :, :]
+ iou_predictions = iou_predictions[:, mask_slice]
+
+ low_res_masks = F.interpolate(masks, scale_factor=2, mode='bilinear', align_corners=False)
+
+ masks = self.model.postprocess_masks(
+ masks,
+ input_size=self.input_size,
+ original_size=self.original_size,
+ )
+
+ return masks.sigmoid(), iou_predictions, low_res_masks.sigmoid()
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self.is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert self.features is not None, "Features must exist if an image has been set."
+ return self.features
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_image(self) -> None:
+ """Resets the currently set image."""
+ self.is_image_set = False
+ self.features = None
+ self.interm_feats = None
+ self.orig_h = None
+ self.orig_w = None
+ self.input_h = None
+ self.input_w = None
+
\ No newline at end of file
diff --git a/zim/utils/__init__.py b/zim/utils/__init__.py
new file mode 100755
index 0000000..45fb1a8
--- /dev/null
+++ b/zim/utils/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .argparser import get_parser
+from .print import print_once, pretty
+from .utils import AverageMeter, ResizeLongestSide
+from .amg import show_mat_anns
\ No newline at end of file
diff --git a/zim/utils/amg.py b/zim/utils/amg.py
new file mode 100755
index 0000000..5ec1029
--- /dev/null
+++ b/zim/utils/amg.py
@@ -0,0 +1,373 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+This source code is based on code from the Segment Anything Model (SAM)
+(https://github.com/facebookresearch/segment-anything).
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import numpy as np
+import torch
+import cv2
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
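+
+# Worked example (added for clarity, not in the original code): for a single 2x2 mask
+# [[1, 0], [1, 1]], column-major flattening gives [1, 1, 0, 1], so the uncompressed RLE is
+# {"size": [2, 2], "counts": [0, 2, 1, 1]}; counts alternate runs of 0s and 1s starting with
+# the number of leading 0s (here 0), and area_from_rle sums counts[1::2] = 2 + 1 = 3 pixels.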
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
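+
+# Added note: because the high-threshold mask is always contained in the low-threshold mask,
+# the ratio above equals the IoU of the two binarizations; a score near 1.0 means the mask
+# barely changes when the cutoff shifts by +/- threshold_offset.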
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. Each layer
+ has (2**i)**2 boxes for the ith layer.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYXY format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+    mask and an indicator of whether the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
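+    # Rows with no mask pixels were zeroed in in_height_coords; adding h pushes
+    # them past any valid row index so they cannot win the min for the top edge
+    # (the same trick is used for the left edge below).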
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
+
+
+def show_mat_anns(image, anns):
+    """
+    Colorizes matte annotations for visualization. Each annotation is
+    alpha-blended onto a black canvas with a random color, using the soft
+    matte in 'logit' (a uint8 map scaled to [0, 1]) when present and the
+    binary 'segmentation' mask otherwise. Returns a uint8 visualization;
+    if there are no annotations, a flat gray image (value 128) is returned.
+    """
+    if len(anns) == 0:
+        return np.zeros_like(image) + 128
+
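+    # Draw from largest to smallest area so smaller mattes end up on top.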
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+
+ image = image.astype(np.float32)
+ colorized_mat = np.zeros_like(image)
+
+ for ann in sorted_anns:
+ color = (np.random.random(3) * 255).astype(np.float32)
+ if 'logit' in ann:
+ mat = ann['logit'].astype(np.float32) / 255.
+ else:
+ mat = ann['segmentation'].astype(np.float32)
+
+ color_mat = np.zeros_like(image) + color[None, None]
+ colorized_mat = color_mat * mat[:, :, None] + colorized_mat * (1. - mat[:, :, None])
+
+ colorized_mat = np.uint8(colorized_mat)
+
+ return colorized_mat
\ No newline at end of file
diff --git a/zim/utils/argparser.py b/zim/utils/argparser.py
new file mode 100755
index 0000000..837f9f6
--- /dev/null
+++ b/zim/utils/argparser.py
@@ -0,0 +1,96 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import os
+import argparse
+from config.config import config_
+
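+# str2bool is used below as an argparse `type=` callback so that values such as
+# "--amp false" are parsed as booleans rather than as truthy non-empty strings.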
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+def get_parser(verbose=False):
+ p = argparse.ArgumentParser("argparser", add_help=False)
+
+ p.add_argument(
+ "--data-root", type=str, default=config_.data_root, help="data root directory"
+ )
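+    # The default comes from the LOCAL_RANK env var (set by torchrun); the
+    # legacy torch.distributed.launch launcher passes --local_rank explicitly.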
+ p.add_argument(
+ "--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0"))
+ )
+ p.add_argument(
+ "--amp", type=str2bool, default=True
+ )
+ p.add_argument(
+ "--ddp", action="store_true"
+ )
+ p.add_argument(
+ "--random-seed", type=int, default=config_.random_seed
+ )
+
+ # network config
+ p.add_argument(
+ "--network-encoder",
+ type=str,
+ default=config_.network.encoder,
+ choices=["vit_b", "vit_l"],
+ )
+ p.add_argument(
+ "--network-decoder",
+ type=str,
+ default=config_.network.decoder,
+ choices=["zim", "sam"],
+ )
+ p.add_argument(
+ "--network-encode-kernel",
+ type=int,
+ default=config_.network.encode_kernel,
+ )
+
+ # evaluation config
+ p.add_argument(
+ "--eval-workers", type=int, default=config_.eval.workers,
+ )
+ p.add_argument(
+ "--eval-image-size", type=int, default=config_.eval.image_size,
+ )
+ p.add_argument(
+ "--eval-prompt-type", type=str, default=config_.eval.prompt_type,
+ )
+ p.add_argument(
+ "--eval-model-list", type=str, default=config_.eval.model_list,
+ )
+ p.add_argument(
+ "--eval-zim-weights",
+ type=str,
+ default=config_.eval.zim_weights,
+ )
+ p.add_argument(
+ "--eval-sam-weights",
+ type=str,
+ default=config_.eval.sam_weights,
+ )
+
+ # dataset config
+ p.add_argument(
+ "--dataset-valset", type=str, default=config_.dataset.valset,
+ )
+ p.add_argument(
+ "--dataset-data-type", type=str, default=config_.dataset.data_type,
+ )
+ p.add_argument(
+ "--dataset-data-list-txt", type=str, default=config_.dataset.data_list_txt,
+ )
+
+ return p
diff --git a/zim/utils/print.py b/zim/utils/print.py
new file mode 100755
index 0000000..6c84986
--- /dev/null
+++ b/zim/utils/print.py
@@ -0,0 +1,20 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+
+def print_once(message):
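+    """Print only on rank 0 (or when torch.distributed is not initialized)."""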
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+ print(message)
+
+def pretty(d, indent=0):
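+    """Recursively print a nested dict via print_once, one key per line, indented one tab per level."""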
+ for key, value in d.items():
+ print_once("\t" * indent + str(key))
+ if isinstance(value, dict):
+ pretty(value, indent + 1)
+ else:
+ print_once("\t" * (indent + 1) + str(value))
diff --git a/zim/utils/utils.py b/zim/utils/utils.py
new file mode 100755
index 0000000..c961d39
--- /dev/null
+++ b/zim/utils/utils.py
@@ -0,0 +1,148 @@
+"""
+Copyright (c) 2024-present Naver Cloud Corp.
+
+This source code is licensed under the license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image, InterpolationMode
+from copy import deepcopy
+from typing import Optional, Tuple, List
+
+class ResizeLongestSide:
+ """
+    Resizes images so that their longest side equals 'target_length', and
+    provides matching methods for resizing coordinates and boxes. Works on
+    both numpy arrays and batched torch tensors.
+ """
+
+ def __init__(self, target_length: int) -> None:
+ self.target_length = target_length
+
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
+ """
+ Expects a numpy array with shape HxWxC in uint8 format.
+ """
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ return np.array(resize(to_pil_image(image), target_size))
+
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).astype(float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array shape Bx4. Requires the original image size
+ in (H, W) format.
+ """
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Expects batched images with shape BxCxHxW and float format. This
+ transformation may not exactly match apply_image. apply_image is
+ the transformation expected by the model.
+ """
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+ return F.interpolate(
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
+ )
+
+ def apply_coords_torch(
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with length 2 in the last dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).to(torch.float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes_torch(
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with shape Bx4. Requires the original image
+ size in (H, W) format.
+ """
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ def apply_mask(self, image: np.ndarray) -> np.ndarray:
+ """
+        Expects a numpy mask array in uint8 format. Uses nearest-neighbor
+        interpolation so that resizing does not introduce new values.
+ """
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ return np.array(resize(to_pil_image(image), target_size, interpolation=InterpolationMode.NEAREST))
+
+ @staticmethod
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+ """
+ Compute the output size given input size and target long side length.
+ """
+ scale = long_side_length * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ neww = int(neww + 0.5)
+ newh = int(newh + 0.5)
+ return (newh, neww)
+
+
+def remove_prefix(text, prefix):
+ if text.startswith(prefix):
+ return text[len(prefix) :]
+ return text
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self, is_ddp):
+ self.is_ddp = is_ddp
+ self.reset()
+
+ def reset(self):
+ self.val = 0.0
+ self.avg = 0.0
+ self.sum = 0.0
+ self.count = 0.0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / (self.count + 1e-5)
+
+ def synch(self, device):
+ if self.is_ddp is False:
+ return
+
+ _sum = torch.tensor(self.sum).to(device)
+ _count = torch.tensor(self.count).to(device)
+
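+        # Reduce (sum) the local statistics onto rank 0; only rank 0 ends up
+        # with the globally aggregated values after this call.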
+ torch.distributed.reduce(_sum, dst=0)
+ torch.distributed.reduce(_count, dst=0)
+
+ if torch.distributed.get_rank() == 0:
+ self.sum = _sum.item()
+ self.count = _count.item()
+ self.avg = self.sum / (self.count + 1e-5)