
Merge pull request #582 from roboflow/sam2-id-free-caching
Adds the ability to cache low-res logits from prompts
PawelPeczek-Roboflow authored Aug 22, 2024
2 parents 37623f3 + 93241f3 commit d3be171
Showing 8 changed files with 282 additions and 21 deletions.
2 changes: 1 addition & 1 deletion CODEOWNERS
@@ -1,2 +1,2 @@
* @PawelPeczek-Roboflow @grzegorz-roboflow @yeldarby @probicheaux @hansent
docs/ @capjamesg
docs/ @capjamesg @PawelPeczek-Roboflow @grzegorz-roboflow
3 changes: 3 additions & 0 deletions docs/server_configuration/environmental_variables.md
@@ -7,3 +7,6 @@ Below is a list of some environmental values that require more in-depth explanation
Environmental variable | Default | Description
------------------------------------------ | ------------------------------------------------------------------------ | -----------
ONNXRUNTIME_EXECUTION_PROVIDERS | "[CUDAExecutionProvider,OpenVINOExecutionProvider,CPUExecutionProvider]" | List of execution providers in priority order; a warning message will be displayed if a provider is not supported on the user's platform
SAM2_MAX_EMBEDDING_CACHE_SIZE | 100 | The number of SAM2 image embeddings that will be held in memory. Embeddings are kept in GPU memory; each embedding takes 16777216 bytes.
SAM2_MAX_LOGITS_CACHE_SIZE | 1000 | The number of SAM2 low-resolution logits entries that will be held in memory. Logits are kept in CPU memory; each entry takes 262144 bytes.
DISABLE_SAM2_LOGITS_CACHE | False | If set to True, disables the caching of SAM2 logits. This can be useful for debugging or in scenarios where memory usage needs to be minimized, but may result in slower performance for repeated similar requests.
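
As a rough, back-of-envelope illustration of what the defaults above imply for memory budgeting (using only the per-item sizes quoted in the table; treat the numbers as approximate):

# Illustrative footprint of the default cache sizes, derived from the table above.
EMBEDDING_BYTES = 16_777_216  # per SAM2 embedding, held in GPU memory
LOGITS_BYTES = 262_144        # per low-res logits entry, held in CPU memory

gpu_mib = 100 * EMBEDDING_BYTES / 2**20   # SAM2_MAX_EMBEDDING_CACHE_SIZE=100 -> 1600 MiB of GPU memory
cpu_mib = 1000 * LOGITS_BYTES / 2**20     # SAM2_MAX_LOGITS_CACHE_SIZE=1000 -> 250 MiB of CPU memory
print(f"embedding cache ~{gpu_mib:.0f} MiB GPU, logits cache ~{cpu_mib:.0f} MiB CPU")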
42 changes: 33 additions & 9 deletions inference/core/entities/requests/sam2.py
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Tuple, Union

from pydantic import BaseModel, Field, root_validator, validator

@@ -68,18 +68,29 @@ class Point(BaseModel):
y: float
positive: bool

def to_hashable(self) -> Tuple[float, float, bool]:
return (self.x, self.y, self.positive)


class Sam2Prompt(BaseModel):
box: Optional[Box] = Field(default=None)
points: Optional[List[Point]] = Field(default=None)

def num_points(self) -> int:
return len(self.points or [])


class Sam2PromptSet(BaseModel):
prompts: Optional[List[Sam2Prompt]] = Field(
default=None,
description="An optional list of prompts for masks to predict. Each prompt can include a bounding box and / or a set of postive or negative points",
)

def num_points(self) -> int:
if not self.prompts:
return 0
return sum(prompt.num_points() for prompt in self.prompts)

def to_sam2_inputs(self):
if self.prompts is None:
return {"point_coords": None, "point_labels": None, "box": None}
@@ -98,16 +109,16 @@ def to_sam2_inputs(self):
return_dict["point_labels"].append(
list(int(point.positive) for point in prompt.points)
)
else:
return_dict["point_coords"].append([])
return_dict["point_labels"].append([])

return_dict = {k: v if v else None for k, v in return_dict.items()}
lengths = set()
for v in return_dict.values():
if isinstance(v, list):
lengths.add(len(v))

if not len(lengths) in [0, 1]:
raise ValueError("All prompts must have the same number of points")
if not any(return_dict["point_coords"]):
return_dict["point_coords"] = None
if not any(return_dict["point_labels"]):
return_dict["point_labels"] = None

return_dict = {k: v if v else None for k, v in return_dict.items()}
return return_dict


@@ -150,3 +161,16 @@ class Sam2SegmentationRequest(Sam2InferenceRequest):
"to select the best mask. For non-ambiguous prompts, such as multiple "
"input prompts, multimask_output=False can give better results.",
)

save_logits_to_cache: bool = Field(
default=False,
description="If True, saves the low-resolution logits to the cache for potential future use. "
"This can speed up subsequent requests with similar prompts on the same image. "
"This feature is ignored if DISABLE_SAM2_LOGITS_CACHE env variable is set True",
)
load_logits_from_cache: bool = Field(
default=False,
description="If True, attempts to load previously cached low-resolution logits for the given image and prompt set. "
"This can significantly speed up inference when making multiple similar requests on the same image. "
"This feature is ignored if DISABLE_SAM2_LOGITS_CACHE env variable is set True",
)
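
For orientation, a minimal sketch of the new prompt helpers in use (coordinate values are invented; only the classes and methods shown in this diff are assumed):

from inference.core.entities.requests.sam2 import Point, Sam2Prompt, Sam2PromptSet

point = Point(x=210.0, y=350.0, positive=True)
prompt_set = Sam2PromptSet(prompts=[Sam2Prompt(points=[point])])

point.to_hashable()           # (210.0, 350.0, True) - the hashable form used for subset comparisons
prompt_set.num_points()       # 1 - total number of points across all prompts
Sam2PromptSet().num_points()  # 0 - an empty prompt set contributes no points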
4 changes: 4 additions & 0 deletions inference/core/env.py
@@ -284,6 +284,10 @@
# Maximum embedding cache size for SAM, default is 10
SAM_MAX_EMBEDDING_CACHE_SIZE = int(os.getenv("SAM_MAX_EMBEDDING_CACHE_SIZE", 10))

SAM2_MAX_EMBEDDING_CACHE_SIZE = int(os.getenv("SAM2_MAX_EMBEDDING_CACHE_SIZE", 100))
SAM2_MAX_LOGITS_CACHE_SIZE = int(os.getenv("SAM2_MAX_LOGITS_CACHE_SIZE", 1000))
DISABLE_SAM2_LOGITS_CACHE = str2bool(os.getenv("DISABLE_SAM2_LOGITS_CACHE", False))

# SAM version ID, default is "vit_h"
SAM_VERSION_ID = os.getenv("SAM_VERSION_ID", "vit_h")
SAM2_VERSION_ID = os.getenv("SAM2_VERSION_ID", "hiera_large")
172 changes: 165 additions & 7 deletions inference/models/sam2/segment_anything2.py
@@ -1,8 +1,9 @@
import base64
import copy
import hashlib
from io import BytesIO
from time import perf_counter
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union

import numpy as np
import rasterio.features
@@ -22,6 +23,7 @@
from inference.core.entities.requests.sam2 import (
Sam2EmbeddingRequest,
Sam2InferenceRequest,
Sam2Prompt,
Sam2PromptSet,
Sam2SegmentationRequest,
)
@@ -30,7 +32,13 @@
Sam2SegmentationPrediction,
Sam2SegmentationResponse,
)
from inference.core.env import DEVICE, SAM2_VERSION_ID, SAM_MAX_EMBEDDING_CACHE_SIZE
from inference.core.env import (
DEVICE,
DISABLE_SAM2_LOGITS_CACHE,
SAM2_MAX_EMBEDDING_CACHE_SIZE,
SAM2_MAX_LOGITS_CACHE_SIZE,
SAM2_VERSION_ID,
)
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.postprocess import masks2poly
@@ -39,6 +47,11 @@
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


class LogitsCacheType(TypedDict):
logits: np.ndarray
prompt_set: Sam2PromptSet


class SegmentAnything2(RoboflowCoreModel):
"""SegmentAnything class for handling segmentation tasks.
@@ -75,6 +88,8 @@ def __init__(self, *args, model_id: str = f"sam2/{SAM2_VERSION_ID}", **kwargs):
self.embedding_cache = {}
self.image_size_cache = {}
self.embedding_cache_keys = []
self.low_res_logits_cache: Dict[Tuple[str, str], LogitsCacheType] = {}
self.low_res_logits_cache_keys = []

self.task_type = "unsupervised-segmentation"

@@ -108,7 +123,7 @@ def embed_image(
Notes:
- Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
- The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
- The cache has a maximum size defined by SAM2_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
the oldest entries are removed.
Example:
@@ -127,7 +142,7 @@
if image_id is None:
image_id = hashlib.md5(img_in.tobytes()).hexdigest()[:12]

if image_id and image_id in self.embedding_cache:
if image_id in self.embedding_cache:
return (
self.embedding_cache[image_id],
self.image_size_cache[image_id],
@@ -140,8 +155,10 @@

self.embedding_cache[image_id] = embedding_dict
self.image_size_cache[image_id] = img_in.shape[:2]
if image_id in self.embedding_cache_keys:
self.embedding_cache_keys.remove(image_id)
self.embedding_cache_keys.append(image_id)
if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
if len(self.embedding_cache_keys) > SAM2_MAX_EMBEDDING_CACHE_SIZE:
cache_key = self.embedding_cache_keys.pop(0)
del self.embedding_cache[cache_key]
del self.image_size_cache[cache_key]
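
Both the embedding cache above and the logits cache added below use the same bookkeeping: re-appending an existing key moves it to the back of the key list, and the oldest key is evicted once the size limit is exceeded. A standalone sketch of that pattern (a hypothetical helper, not part of this diff):

from typing import Any, Dict, Hashable, List

def put_with_eviction(cache: Dict, keys: List[Hashable], key: Hashable, value: Any, max_size: int) -> None:
    # Store the value, refresh the key's recency, and evict the oldest entry when over budget.
    cache[key] = value
    if key in keys:
        keys.remove(key)
    keys.append(key)
    if len(keys) > max_size:
        oldest = keys.pop(0)
        del cache[oldest]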
@@ -201,8 +218,10 @@ def segment_image(
image: Optional[InferenceRequestImage],
image_id: Optional[str] = None,
prompts: Optional[Union[Sam2PromptSet, dict]] = None,
mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
multimask_output: Optional[bool] = True,
mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
save_logits_to_cache: bool = False,
load_logits_from_cache: bool = False,
**kwargs,
):
"""
@@ -213,9 +232,11 @@
image (Any): The image to be segmented.
image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks.
prompts (Optional[List[Sam2Prompt]]): List of prompts to use for segmentation. Defaults to None.
mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input low_res_logits for the image.
multimask_output (bool): Flag to decide whether multiple mask proposals should be predicted (the most
promising one will be returned).
save_logits_to_cache (bool): Flag to decide whether the low-res logits produced for this request should be saved to the cache.
load_logits_from_cache (bool): Flag to decide whether cached logits from a prior prompt set may be used to seed this prediction.
**kwargs: Additional keyword arguments.
Returns:
@@ -236,6 +257,10 @@
- The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
the oldest entries are removed.
"""
load_logits_from_cache = (
load_logits_from_cache and not DISABLE_SAM2_LOGITS_CACHE
)
save_logits_to_cache = save_logits_to_cache and not DISABLE_SAM2_LOGITS_CACHE
with torch.inference_mode():
if image is None and not image_id:
raise ValueError("Must provide either image or cached image_id")
@@ -252,12 +277,23 @@
self.predictor._orig_hw = [original_image_size]
self.predictor._is_batch = False
args = dict()
prompt_set: Sam2PromptSet
if prompts:
if type(prompts) is dict:
args = Sam2PromptSet(**prompts).to_sam2_inputs()
prompt_set = Sam2PromptSet(**prompts)
args = prompt_set.to_sam2_inputs()
else:
prompt_set = prompts
args = prompts.to_sam2_inputs()
else:
prompt_set = Sam2PromptSet()

if mask_input is None and load_logits_from_cache:
mask_input = maybe_load_low_res_logits_from_cache(
image_id, prompt_set, self.low_res_logits_cache
)

args = pad_points(args)
masks, scores, low_resolution_logits = self.predictor.predict(
mask_input=mask_input,
multimask_output=multimask_output,
@@ -271,8 +307,105 @@
low_resolution_logits=low_resolution_logits,
)

if save_logits_to_cache:
self.add_low_res_logits_to_cache(
low_resolution_logits, image_id, prompt_set
)

return masks, scores, low_resolution_logits
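
Taken together, the two flags enable an iterative-prompting flow along these lines (a sketch only; the model handle, the image object, and the image_id value are assumptions, not part of this diff):

# Hypothetical two-step refinement on the same image (illustrative values).
first = Sam2PromptSet(prompts=[Sam2Prompt(points=[Point(x=210.0, y=350.0, positive=True)])])
masks, scores, logits = model.segment_image(
    image=image, image_id="example", prompts=first,
    save_logits_to_cache=True,   # store this prompt's low-res logits
)

# Adding one more point makes `first` a strict subset of `second`, so the cached
# logits can seed the new prediction via mask_input.
second = Sam2PromptSet(prompts=[Sam2Prompt(points=[
    Point(x=210.0, y=350.0, positive=True),
    Point(x=120.0, y=400.0, positive=False),
])])
masks, scores, logits = model.segment_image(
    image=image, image_id="example", prompts=second,
    load_logits_from_cache=True, save_logits_to_cache=True,
)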

def add_low_res_logits_to_cache(
self, logits: np.ndarray, image_id: str, prompt_set: Sam2PromptSet
) -> None:
logits = logits[:, None, :, :]
prompt_id = hash_prompt_set(image_id, prompt_set)
self.low_res_logits_cache[prompt_id] = {
"logits": logits,
"prompt_set": prompt_set,
}
if prompt_id in self.low_res_logits_cache_keys:
self.low_res_logits_cache_keys.remove(prompt_id)
self.low_res_logits_cache_keys.append(prompt_id)
if len(self.low_res_logits_cache_keys) > SAM2_MAX_LOGITS_CACHE_SIZE:
cache_key = self.low_res_logits_cache_keys.pop(0)
del self.low_res_logits_cache[cache_key]


def hash_prompt_set(image_id: str, prompt_set: Sam2PromptSet) -> Tuple[str, str]:
"""Computes unique hash from a prompt set."""
md5_hash = hashlib.md5()
md5_hash.update(str(prompt_set).encode("utf-8"))
return image_id, md5_hash.hexdigest()[:12]


def maybe_load_low_res_logits_from_cache(
image_id: str,
prompt_set: Sam2PromptSet,
cache: Dict[Tuple[str, str], LogitsCacheType],
) -> Optional[np.ndarray]:
"Loads prior masks from the cache by searching over possibel prior prompts."
prompts = prompt_set.prompts
if not prompts:
return None

return find_prior_prompt_in_cache(prompt_set, image_id, cache)


def find_prior_prompt_in_cache(
initial_prompt_set: Sam2PromptSet,
image_id: str,
cache: Dict[Tuple[str, str], LogitsCacheType],
) -> Optional[np.ndarray]:
"""
Searches the cache for a previously used prompt set that is a strict subset of this one.
"""

logits_for_image = [cache[k] for k in cache if k[0] == image_id]
maxed_size = 0
best_match: Optional[np.ndarray] = None
desired_size = initial_prompt_set.num_points() - 1
for cached_dict in logits_for_image[::-1]:
logits = cached_dict["logits"]
prompt_set: Sam2PromptSet = cached_dict["prompt_set"]
is_viable = is_prompt_strict_subset(prompt_set, initial_prompt_set)
if not is_viable:
continue

size = prompt_set.num_points()
# short-circuit the search if we find a prompt with exactly one fewer point (the most recent possible mask)
if size == desired_size:
return logits
if size >= maxed_size:
maxed_size = size
best_match = logits

return best_match


def is_prompt_strict_subset(
prompt_set_sub: Sam2PromptSet, prompt_set_super: Sam2PromptSet
) -> bool:
if prompt_set_sub == prompt_set_super:
return False

super_copy = [p for p in prompt_set_super.prompts]
for prompt_sub in prompt_set_sub.prompts:
found_match = False
for prompt_super in super_copy:
is_sub = prompt_sub.box == prompt_super.box
is_sub = is_sub and set(
p.to_hashable() for p in prompt_sub.points or []
) <= set(p.to_hashable() for p in prompt_super.points or [])
if is_sub:
super_copy.remove(prompt_super)
found_match = True
break
if not found_match:
return False

# every prompt in prompt_set_sub has a matching super prompt
return True
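
For intuition, the strict-subset relation that drives cache reuse behaves roughly like this (invented coordinates):

a = Sam2PromptSet(prompts=[Sam2Prompt(points=[Point(x=1.0, y=2.0, positive=True)])])
b = Sam2PromptSet(prompts=[Sam2Prompt(points=[
    Point(x=1.0, y=2.0, positive=True),
    Point(x=5.0, y=6.0, positive=False),
])])

is_prompt_strict_subset(a, b)  # True  - b extends a's prompt with one extra point
is_prompt_strict_subset(b, a)  # False - a is missing one of b's points
is_prompt_strict_subset(a, a)  # False - equal prompt sets are not a strict subset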


def choose_most_confident_sam_prediction(
masks: np.ndarray,
@@ -353,3 +486,28 @@ def turn_segmentation_results_into_api_response(
time=perf_counter() - inference_start_timestamp,
predictions=predictions,
)


def pad_points(args: Dict[str, Any]) -> Dict[str, Any]:
"""
Pad arguments to be passed to sam2 model with not_a_point label (-1).
This is necessary when there are multiple prompts per image so that a tensor can be created.
Also pads empty point lists with a dummy non-point entry.
"""
args = copy.deepcopy(args)
if args["point_coords"] is not None:
max_len = max(max(len(prompt) for prompt in args["point_coords"]), 1)
for prompt in args["point_coords"]:
for _ in range(max_len - len(prompt)):
prompt.append([0, 0])
for label in args["point_labels"]:
for _ in range(max_len - len(label)):
label.append(-1)
else:
if args["point_labels"] is not None:
raise ValueError(
"Can't have point labels without corresponding point coordinates"
)
return args
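
For example, given two prompts where one has no points, the padding behaves as follows (a sketch; the input shape follows to_sam2_inputs above):

args = {
    "point_coords": [[[210.0, 350.0], [120.0, 400.0]], []],
    "point_labels": [[1, 0], []],
    "box": None,
}
padded = pad_points(args)
# padded["point_coords"] -> [[[210.0, 350.0], [120.0, 400.0]], [[0, 0], [0, 0]]]
# padded["point_labels"] -> [[1, 0], [-1, -1]]
# Every prompt now has the same number of points, so a dense tensor can be built.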
Binary file not shown.
5 changes: 5 additions & 0 deletions tests/inference/models_predictions_tests/conftest.py
@@ -21,6 +21,7 @@
BEER_IMAGE_PATH = os.path.join(ASSETS_DIR, "beer.jpg")
TRUCK_IMAGE_PATH = os.path.join(ASSETS_DIR, "truck.jpg")
SAM2_TRUCK_LOGITS = os.path.join(ASSETS_DIR, "low_res_logits.npy")
SAM2_TRUCK_MASK_FROM_CACHE = os.path.join(ASSETS_DIR, "mask_from_cached_logits.npy")


@pytest.fixture(scope="function")
@@ -190,6 +191,10 @@ def sam2_small_model() -> Generator[str, None, None]:
def sam2_small_truck_logits() -> Generator[np.ndarray, None, None]:
yield np.load(SAM2_TRUCK_LOGITS)

@pytest.fixture(scope="function")
def sam2_small_truck_mask_from_cached_logits() -> Generator[np.ndarray, None, None]:
yield np.load(SAM2_TRUCK_MASK_FROM_CACHE)


def fetch_and_place_model_in_cache(
model_id: str,
