Skip to content

Commit

Permalink
[ref] cleaning up PR
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Oct 1, 2024
1 parent 5f7a83f commit 40d3ee4
Show file tree
Hide file tree
Showing 19 changed files with 10 additions and 2,137 deletions.
34 changes: 0 additions & 34 deletions bioimage_embed/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,39 +39,6 @@
DEFAULT_AUGMENTATION = A.Compose(DEFAULT_AUGMENTATION_LIST)
DEFAULT_ALBUMENTATION = A.Compose(DEFAULT_AUGMENTATION_LIST)

DEFAULT_AUGMENTATION_LIST = [
# Flip the images horizontally or vertically with a 50% chance
A.OneOf(
[
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
],
p=0.5,
),
# Rotate the images by a random angle within a specified range
A.Rotate(limit=45, p=0.5),
# Randomly scale the image intensity to adjust brightness and contrast
A.RandomGamma(gamma_limit=(80, 120), p=0.5),
# Apply random elastic transformations to the images
A.ElasticTransform(
alpha=1,
sigma=50,
alpha_affine=50,
p=0.5,
),
# Shift the image channels along the intensity axis
A.ChannelShuffle(p=0.5),
# Add a small amount of noise to the images
A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
# Crop a random part of the image and resize it back to the original size
A.RandomResizedCrop(
height=512, width=512, scale=(0.9, 1.0), ratio=(0.9, 1.1), p=0.5
),
# Adjust image intensity with a specified range for individual channels
A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
]

DEFAULT_AUGMENTATION = A.Compose(DEFAULT_AUGMENTATION_LIST)

class VisionWrapper:
def __init__(self, transform_dict, *args, **kwargs):
Expand All @@ -87,7 +54,6 @@ def __call__(self, image):
return None, 0



class VisionWrapperSupervised:
def __call__(self, data):
raise NotImplementedError
20 changes: 0 additions & 20 deletions bioimage_embed/lightning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,6 @@ def __call__(self, incoming):
return self.collate_filter_for_none(incoming)


# https://stackoverflow.com/questions/74931838/cant-pickle-local-object-evaluationloop-advance-locals-batch-to-device-pyto
class Collator:
def collate_filter_for_none(self, batch):
"""
Collate function that filters out None values from the batch.
Args:
batch: The batch to be filtered.
Returns:
The filtered batch.
"""
batch = list(filter(lambda x: x is not None, batch))
return torch.utils.data.dataloader.default_collate(batch)

def __call__(self, incoming):
# do stuff with incoming
return self.collate_filter_for_none(incoming)


class DataModule(pl.LightningDataModule):
"""
A PyTorch Lightning DataModule for handling dataset loading and splitting.
Expand Down
40 changes: 0 additions & 40 deletions bioimage_embed/lightning/tests/test_channel_aware.py

This file was deleted.

37 changes: 0 additions & 37 deletions bioimage_embed/lightning/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@
variational_loss -> loss - recon_loss
"""

# x_recon -> output of the model
# z -> latent space
# data -> input to the model
# target -> target for supervised learning
# recon_loss -> reconstruction loss
# loss -> total loss
# variational_loss -> loss - recon_loss


class AutoEncoder(pl.LightningModule):
args = argparse.Namespace(
Expand Down Expand Up @@ -66,8 +58,6 @@ def __init__(self, model, args=SimpleNamespace()):
# TODO update all models to use this for export to onxx
# self.example_input_array = torch.randn(1, *self.model.input_dim)
# self.model.train()
# keep a handle on metrics logged by the model
self.metrics = {}

def forward(self, x: torch.Tensor) -> ModelOutput:
"""
Expand Down Expand Up @@ -153,31 +143,6 @@ def eval_step(self, batch, batch_idx):
"""
return self.predict_step(batch, batch_idx)

def test_step(self, batch, batch_idx):
# x, y = batch
model_output = self.eval_step(batch, batch_idx)
self.log_dict(
{
"loss/test": model_output.loss,
"mse/test": F.mse_loss(model_output.recon_x, model_output.data),
"recon_loss/test": model_output.recon_loss,
"variational_loss/test": model_output.loss - model_output.recon_loss,
}
)
return model_output.loss

# Fangless function to be overloaded later
def batch_to_xy(self, batch):
x, y = batch
return x, y

def eval_step(self, batch, batch_idx):
"""
This function should be overloaded in the child class to implement the evaluation logic.
"""
model_output = self.predict_step(batch, batch_idx)
return model_output

# def lr_scheduler_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
# # Implement your own logic for updating the lr scheduler
# # This method will be called at each training step
Expand Down Expand Up @@ -227,8 +192,6 @@ def log_tensorboard(self, model_output, x):
self.global_step,
)

class AE(AutoEncoder):
pass

class AE(AutoEncoder):
pass
Expand Down
97 changes: 0 additions & 97 deletions bioimage_embed/models/o2vae_shapeembed_integration.diff

This file was deleted.

12 changes: 1 addition & 11 deletions bioimage_embed/shapes/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def eval_step(self, batch, batch_idx):
[
loss_ops.diagonal_loss(),
loss_ops.symmetry_loss(),
loss_ops.non_negative_loss(),
# loss_ops.triangle_inequality(),
loss_ops.non_negative_loss(),
# loss_ops.clockwise_order_loss(),
]
)
Expand All @@ -68,16 +68,6 @@ def __init__(self, model, args=SimpleNamespace()):
super().__init__(model, args)


class MaskEmbed(MaskEmbedMixin, AutoEncoderUnsupervised):
def __init__(self, model, args=SimpleNamespace()):
super().__init__(model, args)


class MaskEmbedSupervised(MaskEmbedMixin, AutoEncoderSupervised):
def __init__(self, model, args=SimpleNamespace()):
super().__init__(model, args)


class FixedOutput(nn.Module):
def __init__(self, tensor):
super().__init__()
Expand Down
5 changes: 2 additions & 3 deletions bioimage_embed/shapes/mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ def mds(d):
:return: A matrix of x, y coordinates.
"""
n = d.size(0)
I = torch.eye(n, dtype=torch.float64)
I = torch.eye(n)
H = I - torch.ones((n, n)) / n

S = -0.5 * H @ d @ H
#eigvals, eigvecs = S.symeig(eigenvectors=True)
eigvals, eigvecs = torch.linalg.eigh(S)
eigvals, eigvecs = S.symeig(eigenvectors=True)

# Sort the eigenvalues and eigenvectors in decreasing order
idx = eigvals.argsort(descending=True)
Expand Down
32 changes: 7 additions & 25 deletions bioimage_embed/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,23 @@
from pathlib import Path
from typer.testing import CliRunner

# def test_main_creates_config():
# # Arrange
# config_path = "test_conf"
# job_name = "test_app"

runner = CliRunner()


@pytest.fixture
def config_dir():
return "test_conf"

# # Act
# main(config_path=config_path, job_name=job_name)

# # Assert
# assert os.path.exists(config_path), "Config directory was not created"
# assert os.path.isfile(os.path.join(config_path, "config.yaml")), "Config file was not created"
@pytest.fixture
def config_file():
return "config.yaml"

# # Clean up
# os.remove(os.path.join(config_path, "config.yaml"))
# os.rmdir(config_path)

# @pytest.mark.parametrize("config_path, job_name", [
# ("conf", "test_app"),
# ("another_conf", "another_job")
# ])
# def test_hydra_initializes(config_path, job_name):
# # Act
# main(config_path=config_path, job_name=job_name)
@pytest.fixture
def config_path(config_dir, config_file):
return Path(config_dir).joinpath(config_file)

# # Assert
# # Here you can assert specifics about the cfg object if needed.
# # Since main does not return anything, you might need to adjust
# # the main function to return the cfg for more thorough testing.

@pytest.fixture
def config_directory_setup(config_dir, config_file, config_path):
Expand Down Expand Up @@ -84,7 +67,6 @@ def test_get_default_config(cfg):
# cfg.recipe.max_epochs = 1



# def test_cli():
# # This test checks if the CLI correctly handles the dataset target input
# result = runner.invoke(app, ["bie_train", "--dataset-target", "bioimage_embed.datasets.FakeImageFolder"])
Expand Down
Loading

0 comments on commit 40d3ee4

Please sign in to comment.