Skip to content

Commit

Permalink
Enable uniform range for subsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jan 25, 2024
1 parent 962907c commit a7cb692
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions direct/common/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,34 @@ def __init__(
self.accelerations = accelerations

self.uniform_range = uniform_range
if uniform_range and (len(center_fractions) != 2 or len(accelerations) != 2):
raise ValueError(
f"When `uniform_range` is True, both `center_fractions` and `accelerations` should have "
f"a length of two. Received center_fractions={center_fractions} and accelerations={accelerations}."
)

self.rng = np.random.RandomState()

def choose_acceleration(self):
if not self.accelerations:
return None

if not self.uniform_range:
if self.uniform_range:
acceleration = self.rng.uniform(low=min(self.accelerations), high=max(self.accelerations), size=1)[0]
if self.center_fractions is None:
return acceleration
center_fraction = self.rng.uniform(
low=min(self.center_fractions), high=max(self.center_fractions), size=1
)[0]
center_fraction = min(acceleration / 100, center_fraction)
else:
choice = self.rng.randint(0, len(self.accelerations))
acceleration = self.accelerations[choice]
if self.center_fractions is None:
return acceleration

center_fraction = self.center_fractions[choice]
return center_fraction, acceleration
raise NotImplementedError("Uniform range is not yet implemented.")
return center_fraction, acceleration

@abstractmethod
def mask_func(self, *args, **kwargs):
Expand Down

0 comments on commit a7cb692

Please sign in to comment.