Skip to content

Commit

Permalink
import strategies dynamically; create_progress_logger
Browse files Browse the repository at this point in the history
  • Loading branch information
abrichr committed Jan 13, 2024
1 parent 2a821f3 commit 44d4a55
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 141 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "external/SoM"]
path = external/SoM
url = https://github.com/OpenAdaptAI/SoM
1 change: 1 addition & 0 deletions external/SoM
Submodule SoM added at 45ed34
2 changes: 1 addition & 1 deletion openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
"PLOT_PERFORMANCE": True,
# VISUALIZATION CONFIGURATIONS
"VISUALIZE_DARK_MODE": False,
"VISUALIZE_RUN_NATIVELY": True,
"VISUALIZE_RUN_NATIVELY": False,
"VISUALIZE_DENSE_TREES": True,
"VISUALIZE_ANIMATIONS": True,
"VISUALIZE_EXPAND_ALL": False, # not recommended for large trees
Expand Down
7 changes: 4 additions & 3 deletions openadapt/strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from loguru import logger


package_name = "openadapt.strategies"
PACKAGE_NAME = "openadapt.strategies"


# Iterate through all modules in the specified package
strategy_by_name = {}
for _, module_name, _ in pkgutil.iter_modules([package_name.replace('.', '/')]):
for _, module_name, _ in pkgutil.iter_modules([PACKAGE_NAME.replace('.', '/')]):
# Import the module
module = importlib.import_module(f"{package_name}.{module_name}")
module = importlib.import_module(f"{PACKAGE_NAME}.{module_name}")
# Filter and add classes ending with 'ReplayStrategy' to the global namespace
names = [
name
Expand All @@ -35,4 +35,5 @@
},
}
logger.info(f"strategy_by_name=\n{pformat(strategy_by_name)}")

globals().update(strategy_by_name)
48 changes: 2 additions & 46 deletions openadapt/strategies/mixins/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class MyReplayStrategy(SAMReplayStrategyMixin):
modeling,
sam_model_registry,
)
import fire
import matplotlib.axes as axes
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -78,8 +77,8 @@ def __init__(
# from https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
# TODO: use points_grid instead, masked by active_window (for performance)
points_per_side=64,
pred_iou_thresh=0.86,
stability_score_thresh=0.95,#0.92,
#pred_iou_thresh=0.86,
#stability_score_thresh=0.95,#0.92,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
# TODO: determine dynamically based on screenshot size
Expand Down Expand Up @@ -383,46 +382,3 @@ def initialize_sam_model(
logger.info(f"unlinking {checkpoint_file_path=}")
checkpoint_file_path.unlink()
return initialize_sam_model(model_name, checkpoint_dir_path)

def create_progress_logger(msg, interval=.1):
"""
Creates a progress logger function with a specified reporting interval.
Args:
interval (float): Every nth % at which to update progress
Returns:
callable: function to pass into urllib.request.urlretrieve as reporthook
"""
last_reported_percent = 0

def download_progress(block_num, block_size, total_size):
nonlocal last_reported_percent
downloaded = block_num * block_size
if total_size > 0:
percent = (downloaded / total_size) * 100
if percent - last_reported_percent >= interval:
sys.stdout.write(f"\r{percent:.1f}% {msg}")
sys.stdout.flush()
last_reported_percent = percent
else:
sys.stdout.write(f"\rDownloaded {downloaded} bytes")
sys.stdout.flush()

return download_progress


def run_on_image(image_path: str):

class DummyReplayStrategy(SAMReplayStrategyMixin):
def get_next_action_event():
pass

logger.info(f"{image_path=}")
image = Image.open(image_path).convert("RGB")
sam = DummyReplayStrategy(None)
sam.get_image_bboxes(image)


if __name__ == "__main__":
fire.Fire(run_on_image)
28 changes: 28 additions & 0 deletions openadapt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,5 +789,33 @@ def get_functions(name: str) -> dict:
return functions


def create_progress_logger(msg, interval=.1):
"""
Creates a progress logger function with a specified reporting interval.
Args:
interval (float): Every nth % at which to update progress
Returns:
callable: function to pass into urllib.request.urlretrieve as reporthook
"""
last_reported_percent = 0

def download_progress(block_num, block_size, total_size):
nonlocal last_reported_percent
downloaded = block_num * block_size
if total_size > 0:
percent = (downloaded / total_size) * 100
if percent - last_reported_percent >= interval:
sys.stdout.write(f"\r{percent:.1f}% {msg}")
sys.stdout.flush()
last_reported_percent = percent
else:
sys.stdout.write(f"\rDownloaded {downloaded} bytes")
sys.stdout.flush()

return download_progress


if __name__ == "__main__":
fire.Fire(get_functions(__name__))
Loading

0 comments on commit 44d4a55

Please sign in to comment.