Skip to content

Commit

Permalink
replace all function names
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Morris committed Aug 23, 2024
1 parent 84717b7 commit 316a93a
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions cyto_dl/image/transforms/generate_jepa_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from monai.transforms import RandomizableTransform
from skimage.segmentation import find_boundaries

from cyto_dl.nn.vits.utils import validate_spatial_dims
from cyto_dl.nn.vits.utils import match_tuple_dimensions


class JEPAMaskGenerator(RandomizableTransform):
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(
"""
assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0]
num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0]
assert mask_size * max(block_aspect_ratio) < min(
num_patches[-2:]
), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches"
Expand Down
4 changes: 2 additions & 2 deletions cyto_dl/nn/vits/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from cyto_dl.nn.vits.blocks import CrossAttentionBlock
from cyto_dl.nn.vits.utils import (
get_positional_embedding,
match_tuple_dimensions,
take_indexes,
validate_spatial_dims,
)


Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images.
"""
super().__init__()
num_patches, patch_size = validate_spatial_dims(spatial_dims, [num_patches, patch_size])
num_patches, patch_size = match_tuple_dimensions(spatial_dims, [num_patches, patch_size])

self.has_cls_token = has_cls_token

Expand Down
6 changes: 3 additions & 3 deletions cyto_dl/nn/vits/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

from cyto_dl.nn.vits.decoder import CrossMAE_Decoder, MAE_Decoder
from cyto_dl.nn.vits.encoder import HieraEncoder, MAE_Encoder
from cyto_dl.nn.vits.utils import validate_spatial_dims
from cyto_dl.nn.vits.utils import match_tuple_dimensions


class MAE_Base(torch.nn.Module, ABC):
def __init__(
self, spatial_dims, num_patches, patch_size, mask_ratio, features_only, context_pixels
):
super().__init__()
num_patches, patch_size, context_pixels = validate_spatial_dims(
num_patches, patch_size, context_pixels = match_tuple_dimensions(
spatial_dims, [num_patches, patch_size, context_pixels]
)

Expand Down Expand Up @@ -213,7 +213,7 @@ def __init__(
features_only=features_only,
context_pixels=context_pixels,
)
num_mask_units = validate_spatial_dims(self.spatial_dims, [num_mask_units])[0]
num_mask_units = match_tuple_dimensions(self.spatial_dims, [num_mask_units])[0]

self._encoder = HieraEncoder(
num_patches=self.num_patches,
Expand Down
4 changes: 2 additions & 2 deletions cyto_dl/nn/vits/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from cyto_dl.nn.vits.blocks import CrossAttentionBlock
from cyto_dl.nn.vits.utils import (
get_positional_embedding,
match_tuple_dimensions,
take_indexes,
validate_spatial_dims,
)


Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
]
)

num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0]
num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0]

self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = get_positional_embedding(
Expand Down

0 comments on commit 316a93a

Please sign in to comment.