Feature/jepa #409
Conversation
Looks good to me... left some minor comments again
y = self.R.randint(0, self.num_patches[-2] - height + 1)
# add block to mask
if self.spatial_dims == 3:
    mask[:, y : y + height, x : x + width] = 1
do you not randomly sample dims in Z here?
no - I'm basing this off of V-JEPA (which uses masks that are constant over time), with the idea that 3D images have spatial redundancy similar to natural videos
gotcha
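The V-JEPA-style scheme discussed above (one spatial block mask broadcast across Z rather than sampled per slice) can be sketched as follows; `sample_block_mask` and the exact shapes are illustrative, not the PR's actual API:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_block_mask(num_patches, height, width, spatial_dims=3):
    """Sample one rectangular block mask in (Y, X). For 3D, the same
    spatial mask is repeated over every Z slice (V-JEPA-style), relying
    on volumetric images having video-like spatial redundancy."""
    # num_patches is (Z, Y, X) for 3D, (Y, X) for 2D
    mask = np.zeros(num_patches, dtype=bool)
    y = rng.integers(0, num_patches[-2] - height + 1)
    x = rng.integers(0, num_patches[-1] - width + 1)
    if spatial_dims == 3:
        # no sampling along Z: every slice gets the identical block
        mask[:, y : y + height, x : x + width] = True
    else:
        mask[y : y + height, x : x + width] = True
    return mask
```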
target_masks = self.get_predict_masks(source.shape[0], device=source.device)

# mean across patches, no cls token to remove
source_embeddings = self.encoder(source)
is the encoder here a ViT? Does it handle the patchifying etc.? Is the source here the small mask?
yes, the ViT encoder does the patchifying. source is an image. This takes all the patches from the source image and predicts the embeddings of all the patches from the target image.
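The step described above - encode all source patches, then predict the target encoder's embeddings at the masked positions - can be sketched with toy linear stand-ins for the ViT encoders. `ToyJEPA` and every name here are illustrative assumptions, not cyto-dl's API:

```python
import torch

class ToyJEPA(torch.nn.Module):
    """Minimal sketch of a JEPA training step: a context encoder embeds
    the source patches, a frozen (EMA) target encoder embeds the same
    patches, and a predictor regresses the target embeddings at the
    masked target positions."""

    def __init__(self, patch_dim=4, emb_dim=16):
        super().__init__()
        self.encoder = torch.nn.Linear(patch_dim, emb_dim)         # stand-in for the ViT encoder
        self.target_encoder = torch.nn.Linear(patch_dim, emb_dim)  # EMA copy, no gradients
        self.predictor = torch.nn.Linear(emb_dim, emb_dim)

    def forward(self, patches, target_mask):
        # patches: (batch, num_patches, patch_dim); target_mask: bool (num_patches,)
        source_embeddings = self.encoder(patches)
        with torch.no_grad():
            target_embeddings = self.target_encoder(patches)
        predictions = self.predictor(source_embeddings)
        # loss only at the masked (target) patch positions
        return torch.nn.functional.mse_loss(
            predictions[:, target_mask], target_embeddings[:, target_mask]
        )
```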
def get_context_embeddings(self, x, mask):
    # mask context pre-embedding to prevent leakage of target information
    context_patches, _, _, _ = self.encoder.patchify(x, 0)
    context_patches = take_indexes(context_patches, mask)
take_indexes seems to be repeating indices along a channel dimension. just want to check that this is expected
yes, this is just applying the indices across all channels of the tensor (tokens x batch x embedding_dim)
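The behavior discussed above - repeating the index tensor across the embedding dimension so a whole token is gathered at once - can be sketched with `torch.gather`; this is a plausible reconstruction, not necessarily the repo's exact implementation:

```python
import torch

def take_indexes(sequences, indexes):
    """Select tokens by index.
    sequences: (tokens, batch, emb_dim); indexes: (n, batch).
    The indices are expanded across the embedding dimension so every
    channel of a token is gathered together."""
    # (n, batch) -> (n, batch, emb_dim)
    expanded = indexes.unsqueeze(-1).repeat(1, 1, sequences.shape[-1])
    return torch.gather(sequences, 0, expanded)
```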
cyto_dl/nn/encoder_decoder.py
Outdated
if ckpt is not None:
    state_dict = torch.load(ckpt)["state_dict"]
    state_dict = {
        k.replace("backbone.", ""): v
can you add a comment about why this is necessary
I'm thinking this guy isn't ready for prime time yet
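For context, the prefix-stripping pattern in the diff above usually exists because a training wrapper registers the encoder as an attribute (e.g. `self.backbone`), so checkpoint keys carry a `backbone.` prefix that the bare encoder does not expect. A hedged sketch, with a hypothetical function name and the assumption that only `backbone.*` keys should be kept:

```python
import torch

def load_backbone_state_dict(model, ckpt_path):
    """Load encoder weights saved under a training wrapper.
    Keys like 'backbone.weight' are renamed to 'weight' so they match
    the bare encoder's parameter names."""
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    state_dict = {
        k.replace("backbone.", ""): v
        for k, v in state_dict.items()
        if k.startswith("backbone.")  # drop non-encoder weights
    }
    model.load_state_dict(state_dict)
```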
cyto_dl/nn/vits/jepa.py
Outdated
)

self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches), 1, emb_dim))
maybe use the general pos embedding function?
updated!
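A "general pos embedding function" of the kind suggested above is typically a fixed sin-cos encoding over the N-D patch grid. A sketch under that assumption (names and the per-axis channel split are illustrative, not cyto-dl's function):

```python
import numpy as np

def sincos_pos_embedding(num_patches, emb_dim):
    """Fixed sin-cos positional embedding for an N-D patch grid.
    Each spatial axis gets emb_dim // ndim channels of 1-D sin-cos
    encoding; the per-axis encodings are concatenated.
    Returns an array of shape (prod(num_patches), emb_dim)."""
    ndim = len(num_patches)
    dim_per_axis = emb_dim // ndim
    assert dim_per_axis % 2 == 0, "each axis needs an even channel count"
    grids = np.meshgrid(*[np.arange(n) for n in num_patches], indexing="ij")
    embeddings = []
    for grid in grids:
        # standard transformer frequency schedule for this axis
        omega = 1.0 / 10000 ** (np.arange(dim_per_axis // 2) / (dim_per_axis / 2.0))
        out = np.einsum("p,d->pd", grid.reshape(-1).astype(float), omega)
        embeddings.append(np.concatenate([np.sin(out), np.cos(out)], axis=1))
    return np.concatenate(embeddings, axis=1)
```

Unlike the learned `torch.zeros` parameter in the diff above, this embedding is deterministic, so it can be shared across model sizes and regenerated for new grid shapes.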
general comment: I think having a quick lookup table of the different configs you are adding would be great. Something like this - https://github.com/MMV-Lab/mmv_im2im/blob/main/tutorials/example_by_use_case.md
agreed, I'll do a separate PR with that for all the models
What does this PR do?
Add Joint Embedding Predictive Architecture (JEPA) and Image World Model (draft) infrastructure.
Fixes #<issue_number>
Before submitting
Did you run the tests with the pytest command?
Did you run pre-commit run -a?
Did you have fun? Make sure you had fun coding 🙃