Skip to content

Commit

Permalink
Option to log aux data
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Feb 15, 2024
1 parent f469f2a commit f736f66
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
1 change: 1 addition & 0 deletions direct/nn/vsharp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class VSharpNetConfig(ModelConfig):
initializer_activation: ActivationType = ActivationType.PRELU
conv_modulation: ModConvType = ModConvType.NONE
aux_in_features: int = 2
log_aux: bool = False
fc_hidden_features: Optional[tuple[int]] = None
fc_groups: int = 1
fc_activation: ModConvActivation = ModConvActivation.SIGMOID
Expand Down
2 changes: 1 addition & 1 deletion direct/nn/vsharp/vsharp.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def __init__(
# pylint: disable=too-many-locals
super().__init__()
for extra_key in kwargs:
if extra_key != "model_name" and not extra_key.startswith("image_"):
if extra_key != "model_name" and extra_key != "log_aux" and not extra_key.startswith("image_"):
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.num_steps = num_steps
self.num_steps_dc_gd = num_steps_dc_gd
Expand Down
4 changes: 4 additions & 0 deletions direct/nn/vsharp/vsharp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@ def _do_iteration(

with autocast(enabled=self.mixed_precision):
if self.cfg.model.conv_modulation != ModConvType.NONE: # type: ignore
if self.cfg.model.log_aux:
data["center_fraction"] = data["center_fraction"] * 100
data["auxiliary_data"] = torch.cat([data["acceleration"], data["center_fraction"]], 1)
if self.cfg.model.log_aux:
data["auxiliary_data"] = data["auxiliary_data"].log()

output_images, output_kspace = self.forward_function(data)
output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images]
Expand Down

0 comments on commit f736f66

Please sign in to comment.