Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Withoutnoisechannel #26

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
d7e467f
add Notebook and BioSRDataLoader
MichPrencipe Aug 8, 2024
6377ba8
add gitignore
MichPrencipe Aug 8, 2024
e7dc797
lowering GPU memory requirement for developement
ashesh-0 Aug 19, 2024
d5a5e67
Merge pull request #1 from MichPrencipe/downscaling
MichPrencipe Aug 21, 2024
1f64380
add install deps file
MichPrencipe Aug 21, 2024
b9460bd
Merge branch 'main' of github.com:MichPrencipe/swin2sr
MichPrencipe Aug 21, 2024
675cd3f
fixed the environment and add libraries for training
MichPrencipe Aug 21, 2024
b4447ac
create the training.py and connection to wandb
MichPrencipe Aug 22, 2024
7b33f04
add predictor class to the notebook
MichPrencipe Aug 27, 2024
863a28c
notebook update
ashesh-0 Aug 27, 2024
90b744f
train for more epochss
MichPrencipe Aug 28, 2024
0d27b75
trial with MSEloss
MichPrencipe Sep 3, 2024
3cf7e26
perform some augmentation, change the training script deleting the tr…
MichPrencipe Sep 6, 2024
874c8b0
perform some augmentation, change the training script deleting the tr…
MichPrencipe Sep 6, 2024
baacaad
stop tracking logdir files
MichPrencipe Sep 6, 2024
6c59aab
perform training with augmented data concatenated with the original data
MichPrencipe Sep 9, 2024
78a420b
add psnr in wandb charts
MichPrencipe Sep 11, 2024
fde200e
fixed name run on wandb, write a notebook to add poisson noise and ga…
MichPrencipe Sep 12, 2024
f414cf3
train with noise, modify the normalization
MichPrencipe Sep 13, 2024
fd84244
debug normalization
MichPrencipe Sep 18, 2024
8519e9c
resize() should take float as input
ashesh-0 Sep 18, 2024
9f588c1
fixed the psnr and the normalization
MichPrencipe Sep 18, 2024
7413dff
fixed the saving file
MichPrencipe Sep 19, 2024
5332310
retrained the network with norm data and non noisy
MichPrencipe Sep 20, 2024
a0cf742
train with noisy, but maybe there are some bug to fix, actually in th…
MichPrencipe Sep 20, 2024
af88640
training with 400 epochs on 768, 768 with callbacks
MichPrencipe Sep 23, 2024
2f2d6ea
try to fix noisy data
MichPrencipe Sep 23, 2024
22d6749
fix biosr_loader
MichPrencipe Sep 23, 2024
34f3115
train on non noisy data
MichPrencipe Sep 23, 2024
eaaee78
train and evaluation with noisy data
MichPrencipe Sep 23, 2024
4b7f82d
add set global seed
MichPrencipe Sep 23, 2024
2dca7f4
fix the determinism of the evaluation
MichPrencipe Sep 23, 2024
c4fadf3
correct the poisson noise factor
MichPrencipe Sep 23, 2024
5442c68
evaluation with correct level of noise
MichPrencipe Sep 23, 2024
d50cb69
train 100 epochs with 1000 poisson noise and 5000 gauss
MichPrencipe Sep 24, 2024
63c0066
train and evaluation correct with 1000 noise and 1000 gauss, without …
MichPrencipe Sep 26, 2024
83962d2
trying to add the noise in the channels
MichPrencipe Sep 26, 2024
01a2a36
tiled prediction evaluation
MichPrencipe Oct 4, 2024
a3e49fd
tiled prediction
MichPrencipe Oct 7, 2024
14d2463
fix c2 in biosrdata, fix the progress bar
MichPrencipe Oct 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
__pycache__/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.ckpt
*.pth
# C extensions
*.so
logdir/
# Distribution / packaging
.Python
env/
venv/
ENV/
env.bak/
venv.bak/
.venv/
pip-wheel-metadata/
wheelhouse/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
debug/
.testmondata

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
target/

# Jupyter Notebook
.ipynb_checkpoints/

# PyCharm
.idea/
*.iml

# VS Code
.vscode/

# pyenv
.python-version

# Celery
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# dotenv
.env
.env.*
.venv

# virtualenv
# Python virtual environment directories and files
venv/
env/
ENV/
.venv/
.env/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

#chkpoints
tesi/transformer/swin2sr/logdir/**/checkpoints/*.ckpt
tesi/transformer/swin2sr/logdir/wandb/*

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# End of the list

9 changes: 9 additions & 0 deletions analysis/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import glob


def get_best_checkpoint(ckpt_dir):
output = []
for filename in glob.glob(ckpt_dir + "/*_best.ckpt"):
output.append(filename)
assert len(output) == 1, '\n'.join(output)
return output[0]
110 changes: 110 additions & 0 deletions analysis/critic_notebook_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Functions used in Critic notebooks
"""
import numpy as np
import torch

from core.model_type import ModelType
from core.psnr import PSNR, RangeInvariantPsnr


def _get_critic_prediction(pred: torch.Tensor, tar: torch.Tensor, D) -> dict:
"""
Given a predicted image and a target image, here we return a per sample prediction of
the critic regarding whether they belong to real or predicted images.
Args:
pred: predicted image
tar: target image
D: discriminator model
"""
pred_label = D(pred)
tar_label = D(tar)
pred_label = torch.sigmoid(pred_label)
tar_label = torch.sigmoid(tar_label)
N = len(pred_label)
pred_label = pred_label.view(N, -1)
tar_label = tar_label.view(N, -1)
return {
'generated': {
'mu': pred_label.mean(dim=1),
'std': pred_label.std(dim=1)
},
'target': {
'mu': tar_label.mean(dim=1),
'std': tar_label.std(dim=1)
}
}


def get_critic_prediction(model, pred_normalized, target_normalized):
pred1, pred2 = pred_normalized.chunk(2, dim=1)
tar1, tar2 = target_normalized.chunk(2, dim=1)
cpred_1 = _get_critic_prediction(pred1, tar1, model.D1)
cpred_2 = _get_critic_prediction(pred2, tar2, model.D2)
return cpred_1, cpred_2


def get_mmse_dict(model, x_normalized, target_normalized, mmse_count, model_type, psnr_type='range_invariant',
compute_kl_loss=False):
assert psnr_type in ['simple', 'range_invariant']
if psnr_type == 'simple':
psnr_fn = PSNR
else:
psnr_fn = RangeInvariantPsnr

img_mmse = 0
avg_logvar = None
assert mmse_count >= 1
for _ in range(mmse_count):
recon_normalized, td_data = model(x_normalized)
ll, dic = model.likelihood(recon_normalized, target_normalized)
recon_img = dic['mean']
img_mmse += recon_img / mmse_count
if model.predict_logvar:
if avg_logvar is None:
avg_logvar = 0
avg_logvar += dic['logvar'] / mmse_count

ll, dic = model.likelihood(recon_normalized, target_normalized)
mse = (img_mmse - target_normalized) ** 2
# batch and the two channels
N = np.prod(mse.shape[:2])
rmse = torch.sqrt(torch.mean(mse.view(N, -1), dim=1))
rmse = rmse.view(mse.shape[:2])
loss_mmse = model.likelihood.log_likelihood(target_normalized, {'mean': img_mmse, 'logvar': avg_logvar})
kl_loss = None
kl_loss_channelwise = None
if compute_kl_loss:
kl_loss = model.get_kl_divergence_loss(td_data).cpu().numpy()
resN = len(td_data['kl_channelwise'])
kl_loss_channelwise = [td_data['kl_channelwise'][i].detach().cpu().numpy() for i in range(resN)]

psnrl1 = psnr_fn(target_normalized[:, 0], img_mmse[:, 0]).cpu().numpy()
psnrl2 = psnr_fn(target_normalized[:, 1], img_mmse[:, 1]).cpu().numpy()

output = {
'mmse_img': img_mmse,
'mmse_rec_loss': loss_mmse,
'img': recon_img,
'rec_loss': ll,
'rmse': rmse,
'psnr_l1': psnrl1,
'psnr_l2': psnrl2,
'kl_loss': kl_loss,
'kl_loss_channelwise': kl_loss_channelwise,
}
if model_type == ModelType.LadderVAECritic:
D_loss = model.get_critic_loss_stats(recon_img, target_normalized)['loss'].cpu().item()
cpred_1, cpred_2 = get_critic_prediction(model, recon_img, target_normalized)
critic = {
'label1': cpred_1,
'label2': cpred_2,
'D_loss': D_loss,
}
output['critic'] = critic
return output


def get_label_separated_loss(loss_tensor):
assert loss_tensor.shape[1] == 2
return -1 * loss_tensor[:, 0].mean(dim=(1, 2)).cpu().numpy(), -1 * loss_tensor[:, 1].mean(dim=(1, 2)).cpu().numpy()
35 changes: 35 additions & 0 deletions analysis/denoiser_splitter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
This is specific to the HDN => uSplit pipeline.
"""
import os

from configs.config_utils import get_configdir_from_saved_predictionfile, load_config


def get_source_channel(pred_fname):
den_config_dir1 = get_configdir_from_saved_predictionfile(pred_fname)
config_temp = load_config(den_config_dir1)
print(pred_fname, config_temp.model.denoise_channel, config_temp.data.ch1_fname, config_temp.data.ch2_fname)
if config_temp.model.denoise_channel == 'Ch1':
ch1 = config_temp.data.ch1_fname
elif config_temp.model.denoise_channel == 'Ch2':
ch1 = config_temp.data.ch2_fname
else:
raise ValueError('Unhandled channel', config_temp.model.denoise_channel)
return ch1


def whether_to_flip(ch1_fname, ch2_fname, reference_config):
"""
When one wants to get the highsnr data, then one does not know if the order of the channels is same as what uSplit predicts.
If not, then one needs to flip the channels.
"""
ch1 = get_source_channel(ch1_fname)
ch2 = get_source_channel(ch2_fname)
channels = [reference_config.data.ch1_fname, reference_config.data.ch2_fname]
assert ch1 in channels, f'{ch1} not in {channels}'
assert ch2 in channels, f'{ch2} not in {channels}'
assert ch1 != ch2, f'{ch1} and {ch2} are same'
if ch1 == reference_config.data.ch2_fname:
return True
return False
69 changes: 69 additions & 0 deletions analysis/double_dip_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os

import matplotlib.pyplot as plt
import numpy as np

from analysis.plot_utils import clean_ax
from core.psnr import RangeInvariantPsnr


def get_psnr(gt, pred):
"""
Order in the prediction is not fixed. So, we compute the psnr of each ground truth with both predictions
and then pick the correct ordering based on the psnr value.
"""
psnr0_0 = RangeInvariantPsnr(gt[0], pred[0])
psnr0_1 = RangeInvariantPsnr(gt[0], pred[1])

psnr1_0 = RangeInvariantPsnr(gt[1], pred[0])
psnr1_1 = RangeInvariantPsnr(gt[1], pred[1])
if psnr0_0 + psnr1_1 > psnr0_1 + psnr1_0:
return psnr0_0, psnr1_1
else:
return psnr0_1, psnr1_0


def step_num(fname: str) -> int:
"""
sum1_499.jpg => 499
"""
return int(fname.split('.')[0].split('_')[-1])


def get_fpath_sequence(prefix, rootdir, extension=None):
"""
Args:
prefix: file name should start with prefix
rootdir:
extension:str
"""
output = []
for fname in os.listdir(rootdir):
if prefix == fname[:len(prefix)]:
if extension is not None:
if fname[-1 * len(extension):] != extension:
continue

output.append(os.path.join(rootdir, fname))

return sorted(output, key=lambda x: step_num(os.path.basename(x)))


def show_imgs_from_np_fpaths(fpath_list, ncols=4, img_sz=5, title_list=None, preprocessing_fn=None):
nrows = int(np.ceil(len(fpath_list) / ncols))
_, ax = plt.subplots(figsize=(img_sz * ncols, nrows * img_sz), ncols=ncols, nrows=nrows)
clean_ax(ax)
if len(ax.shape) == 1:
ax = ax.reshape(1, -1)
for ridx in range(nrows):
for cidx in range(ncols):
fpath_idx = ridx * nrows + cidx
fpath = fpath_list[fpath_idx]
img = np.load(fpath)
if preprocessing_fn is not None:
img = preprocessing_fn(img)

ax[ridx, cidx].imshow(img[0])
if isinstance(title_list, list):
title = title_list[fpath_idx]
ax[ridx, cidx].set_title(title)
Loading