# Feature/hiera #418
Changes from 29 commits
New file (`@@ -0,0 +1,52 @@`):

```yaml
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: im2im/mae.yaml
  - override /model: im2im/hiera.yaml
  - override /callbacks: default.yaml
  - override /trainer: gpu.yaml
  - override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["dev"]
seed: 12345

experiment_name: YOUR_EXP_NAME
run_name: YOUR_RUN_NAME

# only source_col is needed for masked autoencoder
source_col: raw
spatial_dims: 3
raw_im_channels: 1

trainer:
  max_epochs: 100
  gradient_clip_val: 10

data:
  path: ${paths.data_dir}/example_experiment_data/segmentation
  cache_dir: ${paths.data_dir}/example_experiment_data/cache
  batch_size: 1
  _aux:
    # 2D
    # patch_shape: [16, 16]
    # 3D
    patch_shape: [16, 16, 16]

callbacks:
  # prediction
  # saving:
  #   _target_: cyto_dl.callbacks.ImageSaver
  #   save_dir: ${paths.output_dir}
  #   save_every_n_epochs: ${model.save_images_every_n_epochs}
  #   stages: ["predict"]
  #   save_input: False
  # training
  saving:
    _target_: cyto_dl.callbacks.ImageSaver
    save_dir: ${paths.output_dir}
    save_every_n_epochs: ${model.save_images_every_n_epochs}
    stages: ["train", "test", "val"]
```
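The comment about merging "parameters from default configurations set above" describes Hydra's config composition: keys set in this experiment file recursively override matching keys from the composed defaults. A minimal stdlib sketch of that merge semantics (an illustration with hypothetical baseline values, not Hydra's actual implementation):

```python
# Sketch of Hydra-style config merging: the experiment file's keys override
# matching keys from the composed defaults, recursing into nested mappings.
# Illustration only - Hydra/OmegaConf do this for you.

def deep_merge(base: dict, override: dict) -> dict:
    """Return base updated with override, merging nested dicts recursively."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Hypothetical values loaded from the defaults list
defaults = {"trainer": {"max_epochs": 10, "accelerator": "gpu"}, "seed": 0}

# Values set in this experiment file
experiment = {"trainer": {"max_epochs": 100, "gradient_clip_val": 10}, "seed": 12345}

config = deep_merge(defaults, experiment)
print(config["trainer"])
# {'max_epochs': 100, 'accelerator': 'gpu', 'gradient_clip_val': 10}
```

Only the keys the experiment file names are replaced; untouched keys (like `accelerator` here) survive from the defaults.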
New file (`@@ -0,0 +1,69 @@`):

```yaml
_target_: cyto_dl.models.im2im.MultiTaskIm2Im

save_images_every_n_epochs: 1
save_dir: ${paths.output_dir}

x_key: ${source_col}

backbone:
  _target_: cyto_dl.nn.vits.mae.HieraMAE
  spatial_dims: ${spatial_dims}
  patch_size: 2 # patch_size * num_patches should be your patch shape
  num_patches: 8 # patch_size * num_patches = img_shape
  num_mask_units: 4 # img_shape / num_mask_units = size of each mask unit in pixels; num_patches / num_mask_units = number of patches per mask unit
```
> **Review:** Clarify what a mask unit is here?

```yaml
  emb_dim: 4
  architecture:
    # mask_unit_attention blocks - attention is only done within a mask unit and not across mask units
    # the total amount of q_stride across the architecture must be less than the number of patches per mask unit
    - repeat: 1
      q_stride: 2
      num_heads: 1
    - repeat: 1
      q_stride: 1
      num_heads: 2
    # self attention transformer - attention is done across all patches, irrespective of which mask unit they're in
    - repeat: 2
      num_heads: 4
      self_attention: True
```

> **Review:** what is `repeat`?
>
> **Review:** So the last layer is global attention and the first two layers are local attention? Is three layers the recommended hierarchy?
>
> **Reply:** Correct. Three layers is small enough to test quickly. All of the models with unit tests are tiny by default in the configs, and the docs note that you should increase the model size if you want good performance.
```yaml
  decoder_layer: 1
  decoder_dim: 16
  mask_ratio: 0.66666666666
  context_pixels: 3
  use_crossmae: True

task_heads: ${kv_to_dict:${model._aux._tasks}}

optimizer:
  generator:
    _partial_: True
    _target_: torch.optim.AdamW
    weight_decay: 0.05

lr_scheduler:
  generator:
    _partial_: True
    _target_: torch.optim.lr_scheduler.OneCycleLR
    max_lr: 0.0001
    epochs: ${trainer.max_epochs}
    steps_per_epoch: 1
    pct_start: 0.1

inference_args:
  sw_batch_size: 1
  roi_size: ${data._aux.patch_shape}
  overlap: 0
  progress: True
  mode: "gaussian"

_aux:
  _tasks:
    - - ${source_col}
      - _target_: cyto_dl.nn.head.mae_head.MAEHead
        loss:
        postprocess:
          input:
            _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
            rescale_dtype: numpy.uint8
          prediction:
            _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
            rescale_dtype: numpy.uint8
```
Changed file (`@@ -1,4 +1,5 @@`):

```diff
-from .cross_mae import CrossMAE_Decoder
-from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT
+from .decoder import CrossMAE_Decoder, MAE_Decoder
+from .encoder import HieraEncoder, MAE_Encoder
+from .mae import MAE, HieraMAE
 from .seg import Seg_ViT, SuperresDecoder
 from .utils import take_indexes
```
> **Review:** should be image shape?
>
> **Review:** can this be a list for ZYX patch size?
>
> **Reply:** Yes - the terminology is confusing here, haha. "Patch" = the small crop extracted from your original image, but "patch" is also the tokenized component of the image fed into the network. The patch size can be either an int (repeated for each spatial dim) or a list of size `spatial_dims`.
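Per the reply above, `patch_size` accepts either an int (repeated for each spatial dim) or a list of length `spatial_dims`. A sketch of that normalization with a hypothetical helper (not the actual cyto_dl implementation):

```python
from typing import Sequence, Union

def normalize_patch_size(
    patch_size: Union[int, Sequence[int]], spatial_dims: int
) -> tuple:
    """Expand an int patch_size to one value per spatial dim; validate a list."""
    if isinstance(patch_size, int):
        return (patch_size,) * spatial_dims
    patch_size = tuple(patch_size)
    if len(patch_size) != spatial_dims:
        raise ValueError(
            f"patch_size must have length {spatial_dims}, got {len(patch_size)}"
        )
    return patch_size

print(normalize_patch_size(2, 3))          # (2, 2, 2)
print(normalize_patch_size([4, 8, 8], 3))  # ZYX patch size -> (4, 8, 8)
```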