Skip to content

Commit

Permalink
Choise for target accel in transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Feb 5, 2024
1 parent fbac620 commit e95b890
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 5 deletions.
4 changes: 2 additions & 2 deletions direct/common/subsample_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
@dataclass
class MaskingConfig(BaseConfig):
name: str = MISSING
accelerations: Tuple[int, ...] = (5,) # Ideally Union[float, int].
accelerations: Tuple[float, ...] = (5.0,) # Ideally Union[float, int].
center_fractions: Optional[Tuple[float, ...]] = (0.1,) # Ideally Optional[Tuple[float, ...]]
uniform_range: bool = False
image_center_crop: bool = False
dynamic_mask: Optional[bool] = None

val_accelerations: Tuple[int, ...] = (5, 10)
val_accelerations: Tuple[float, ...] = (5.0, 10.0)
val_center_fractions: Optional[Tuple[float, ...]] = (0.1, 0.05)
1 change: 1 addition & 0 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class TransformsConfig(BaseConfig):
image_recon_type: str = "rss"
pad_coils: Optional[int] = None
use_seed: bool = True
target_accelerations: Optional[tuple[float, ...]] = None


@dataclass
Expand Down
28 changes: 27 additions & 1 deletion direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:

for i in range(nz):
sampling_mask_z, acceleration_z, center_fraction_z = self.mask_func(
shape=shape, seed=seed, return_acs=False, return_acceleration=True
shape=shape, seed=dynamic_seeds[i], return_acs=False, return_acceleration=True
)

sampling_mask.append(sampling_mask_z.to(sample["kspace"].dtype))
Expand Down Expand Up @@ -1642,6 +1642,30 @@ def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample


class AddTargetAcceleration(DirectTransform):
"""This will find the sample acceleration provided by the mask function.
Then it will add a tensor with the corresponding acceleration in the target accelerations."""
def __init__(self, mask_func: Callable, target_accelerations: tuple[int, ...]):
super().__init__()
if mask_func.uniform_range:
raise ValueError(f"Cannot apply this transform for `uniform_range`=True for the mask function.")
self.mask_func_accelerations = mask_func.accelerations
self.target_accelerations = target_accelerations

def __call__(self, sample: Dict[str, Any]):
# Convert tensor to Python scalar
sample_acceleration = sample["acceleration"].item()

# Find the index of the value in the list
ind = self.mask_func_accelerations.index(sample_acceleration)

target_acceleration = self.target_accelerations[ind]
sample["acceleration"] = torch.tensor([target_acceleration], dtype=sample["acceleration"].dtype)

return sample


class ModuleWrapper:
class SubWrapper:
def __init__(self, transform, toggle_dims):
Expand Down Expand Up @@ -2055,6 +2079,7 @@ def build_mri_transforms(
scaling_key: TransformKey = TransformKey.MASKED_KSPACE,
scale_percentile: Optional[float] = 0.99,
use_seed: bool = True,
target_accelerations: Optional[tuple[int,...]] = None,
) -> object:
"""Build transforms for MRI.
Expand Down Expand Up @@ -2228,6 +2253,7 @@ def build_mri_transforms(
return_acs=estimate_sensitivity_maps,
dynamic_mask=dynamic_mask,
),
AddTargetAcceleration(mask_func, target_accelerations),
]
if compute_and_apply_padding:
mri_transforms += [ApplyZeroPadding("sampling_mask", "padding")]
Expand Down
2 changes: 2 additions & 0 deletions direct/nn/adaptive/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class PolicyConfig(ModelConfig):
use_softplus: bool = True
slope: float = 10
fix_sign_leakage: bool = True
acceleration: Optional[float] = None
center_fraction: Optional[float] = None


@dataclass
Expand Down
49 changes: 49 additions & 0 deletions direct/nn/adaptive/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
fix_sign_leakage: bool = True,
st_slope: float = 10,
st_clamp: bool = False,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
"""Inits :class:`ParameterizedPolicy`.
Expand Down Expand Up @@ -129,6 +131,9 @@ def __init__(

self.sampling_type = sampling_type

self.acceleration = acceleration
self.center_fraction = center_fraction


class ParameterizedStaticPolicy(ParameterizedPolicy):
"""Base Parameterized policy model for non dynamic 2D or 3D data."""
Expand All @@ -142,6 +147,8 @@ def __init__(
fix_sign_leakage: bool = True,
st_slope: float = 10,
st_clamp: bool = False,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
"""Inits :class:`ParameterizedStaticPolicy`.
Expand Down Expand Up @@ -173,6 +180,8 @@ def __init__(
fix_sign_leakage=fix_sign_leakage,
st_slope=st_slope,
st_clamp=st_clamp,
acceleration=acceleration,
center_fraction=center_fraction,
)

@abstractmethod
Expand Down Expand Up @@ -244,6 +253,16 @@ def forward(
nonzero_idcs = (mask == 0).nonzero(as_tuple=True)
probs_to_norm = masked_prob_mask[nonzero_idcs].reshape(batch_size, -1)

if (self.acceleration is not None) and (self.center_fraction is not None):
acceleration = self.acceleration
center_fraction = center_fraction
else:
if (acceleration is None) or (center_fraction is None):
raise ValueError(f"One of `acceleration` or `center_fraction` received None for a value. "
f"This should not be None when `StraightThroughPolicy` is initialized "
f"with `acceleration` or `center_fraction` with None value."
)

# Rescale probabilities to desired sparsity.
budget = self.num_actions / acceleration - self.num_actions * center_fraction
if isinstance(budget, float):
Expand Down Expand Up @@ -289,6 +308,8 @@ def __init__(
fix_sign_leakage: bool = True,
st_slope: float = 10,
st_clamp: bool = False,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
super().__init__(
kspace_shape=kspace_shape,
Expand All @@ -298,6 +319,8 @@ def __init__(
fix_sign_leakage=fix_sign_leakage,
st_slope=st_slope,
st_clamp=st_clamp,
acceleration=acceleration,
center_fraction=center_fraction,
)

def dim_check(self, kspace: torch.Tensor) -> None:
Expand All @@ -317,6 +340,8 @@ def __init__(
fix_sign_leakage: bool = True,
st_slope: float = 10,
st_clamp: bool = False,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
super().__init__(
kspace_shape=kspace_shape,
Expand All @@ -326,6 +351,8 @@ def __init__(
fix_sign_leakage=fix_sign_leakage,
st_slope=st_slope,
st_clamp=st_clamp,
acceleration=acceleration,
center_fraction=center_fraction,
)

def dim_check(self, kspace: torch.Tensor) -> None:
Expand All @@ -348,6 +375,8 @@ def __init__(
fix_sign_leakage: bool = True,
st_slope: float = 10,
st_clamp: bool = False,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
"""Inits :class:`ParameterizedDynamicOrMultislice2dPolicy`.
Expand Down Expand Up @@ -385,6 +414,8 @@ def __init__(
fix_sign_leakage=fix_sign_leakage,
st_slope=st_slope,
st_clamp=st_clamp,
acceleration=acceleration,
center_fraction=center_fraction,
)

def forward(
Expand Down Expand Up @@ -423,6 +454,16 @@ def forward(
batch_size, _, slices, height, width, _ = kspace.shape # batch, coils, time, height, width, complex
masks = [mask.expand(batch_size, 1, slices, height, width, 1)]

if (self.acceleration is not None) and (self.center_fraction is not None):
acceleration = self.acceleration
center_fraction = center_fraction
else:
if (acceleration is None) or (center_fraction is None):
raise ValueError(f"One of `acceleration` or `center_fraction` received None for a value. "
f"This should not be None when `StraightThroughPolicy` is initialized "
f"with `acceleration` or `center_fraction` with None value."
)

budget = self.num_actions / acceleration - self.num_actions * center_fraction
if isinstance(budget, float):
budget = int(budget)
Expand Down Expand Up @@ -574,6 +615,8 @@ def __init__(
st_slope: float = 10,
st_clamp: bool = False,
non_uniform: bool = False,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
"""Inits :class:`ParameterizedDynamic2dPolicy`.
Expand Down Expand Up @@ -608,6 +651,8 @@ def __init__(
fix_sign_leakage=fix_sign_leakage,
st_slope=st_slope,
st_clamp=st_clamp,
acceleration=acceleration,
center_fraction=center_fraction,
)
self.non_uniform = non_uniform

Expand All @@ -626,6 +671,8 @@ def __init__(
st_slope: float = 10,
st_clamp: bool = False,
non_uniform: bool = False,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
"""Inits :class:`ParameterizedMultislice2dPolicy`.
Expand Down Expand Up @@ -660,5 +707,7 @@ def __init__(
fix_sign_leakage=fix_sign_leakage,
st_slope=st_slope,
st_clamp=st_clamp,
acceleration=acceleration,
center_fraction=center_fraction,
)
self.non_uniform = non_uniform
19 changes: 17 additions & 2 deletions direct/nn/adaptive/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,8 @@ def __init__(
sampler_cwn_conv: bool = False,
num_time_steps: Optional[int] = None,
num_slices: Optional[int] = None,
acceleration: Optional[float] = None,
center_fraction: Optional[float] = None,
):
super().__init__()

Expand Down Expand Up @@ -572,14 +574,17 @@ def __init__(
for _ in range(num_layers):
self.layers.append(st_policy_block(**st_policy_block_kwargs))

self.acceleration = acceleration
self.center_fraction = center_fraction

def forward(
self,
masked_kspace: torch.Tensor,
mask: torch.Tensor,
sensitivity_map: torch.Tensor,
kspace: torch.Tensor,
acceleration: float | torch.Tensor,
center_fraction: float | torch.Tensor,
acceleration: Optional[float | torch.Tensor] = None,
center_fraction: Optional[float | torch.Tensor] = None,
padding: Optional[torch.Tensor] = None,
):
if self.sampling_type in [
Expand All @@ -594,6 +599,16 @@ def forward(
masks = [mask]
prob_masks = []

if (self.acceleration is not None) and (self.center_fraction is not None):
acceleration = self.acceleration
center_fraction = center_fraction
else:
if (acceleration is None) or (center_fraction is None):
raise ValueError(f"One of `acceleration` or `center_fraction` received None for a value. "
f"This should not be None when `StraightThroughPolicy` is initialized "
f"with `acceleration` or `center_fraction` with None value."
)

budget = self.num_actions / acceleration - self.num_actions * center_fraction
if isinstance(budget, float):
budget = int(budget)
Expand Down

0 comments on commit e95b890

Please sign in to comment.