-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
141 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Snakefile | ||
|
||
from snakemake.remote import AUTO | ||
# from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider | ||
from snakemake.remote.zenodo import RemoteProvider | ||
from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider | ||
|
||
HTTP = HTTPRemoteProvider() | ||
# HTTP = HTTPRemoteProvider() | ||
import numpy as np | ||
from PIL import Image | ||
import os | ||
|
||
# zenodo = RemoteProvider(deposition=7388245) | ||
# HTTP = HTTPRemoteProvider() | ||
# HTTP.remote("mefs.tar.gz", keep_local=True) | ||
# breakpoint() | ||
|
||
# | ||
def split_masks_pil(input_file, output_dir): | ||
input_path = Path(input_file) | ||
output_dir = Path(output_dir) | ||
with Image.open(input_path) as img: | ||
# Get unique class labels present in the image | ||
class_labels = set(img.getdata()) - set([0]) | ||
output_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
# Create a separate binary image for each class label | ||
for label in class_labels: | ||
# Convert the mask image to a binary image (black and white) with only the current label | ||
binary_img = Image.new("1", img.size) | ||
binary_img.putdata([1 if pixel == label else 0 for pixel in img.getdata()]) | ||
|
||
# Save the binary image in the output directory | ||
output_path = os.path.join(output_dir, f"label_{label}.png") | ||
binary_img.save(output_path) | ||
|
||
|
||
def split_masks_np(input_file, output_dir): | ||
mask = np.array(Image.open(input_file)) | ||
unique_labels = np.unique(mask) | ||
class_labels = set(unique_labels) - set([0]) | ||
# Create the output directory if it doesn't exist | ||
os.makedirs(output_dir, exist_ok=True) | ||
|
||
# Loop through each unique class label and create a separate binary image | ||
for label in class_labels: | ||
binary_mask = (mask == label).astype(np.uint8) | ||
binary_img = Image.fromarray( | ||
binary_mask * 255 | ||
) # Convert to 8-bit binary image (0 or 255) | ||
output_file = os.path.join(output_dir, f"label_{label}.png") | ||
binary_img.save(output_file) | ||
|
||
|
||
split_masks = split_masks_pil | ||
|
||
|
||
def get_mask_files(wildcards): | ||
# Get the list of mask files | ||
ckpt = checkpoints.untar.get(**wildcards) | ||
mask_file, = glob_wildcards("mefs/data/processed/Control/{mask_file}.png") | ||
return expand("torchvision/Control/{mask_file}", mask_file=mask_file) | ||
# breakpoint() | ||
# return mask_files | ||
|
||
|
||
def get_mask_label_files(wildcards): | ||
# breakpoint() | ||
mask_files = get_mask_files(wildcards) | ||
# breakpoint | ||
# ckpt = checkpoints.split_masks.get(**wildcards) | ||
mask_file, label = glob_wildcards("torchvision/Control/{mask_file}/label_{label}.png") | ||
# mask_file = [s.replace('/', '_') for s in string_list] | ||
# breakpoint() | ||
return expand( | ||
"torchvision/Control/{mask_file}_{label}.png", | ||
zip, | ||
mask_file=mask_file, | ||
label=label, | ||
) | ||
|
||
|
||
rule all: | ||
input: | ||
"mefs.tar.gz", | ||
"mefs", | ||
get_mask_files, | ||
get_mask_label_files, | ||
|
||
|
||
# rule download_and_untar: | ||
# input: | ||
# file= HTTP.remote( | ||
# "https://zenodo.org/record/7388245/files/mefs.tar.gz", keep_local=False | ||
# ) | ||
# output: | ||
# "mefs.tar.gz", | ||
# shell: | ||
# "cp {input.file} {output}" | ||
|
||
|
||
checkpoint untar: | ||
input: | ||
"mefs.tar.gz", | ||
output: | ||
directory("mefs"), | ||
shell: | ||
"tar -xzf {input} -p" | ||
|
||
|
||
# New rule to split mask images into individual PNG files | ||
checkpoint split_masks: | ||
input: | ||
# checkpoint=lambda wildcards: checkpoints.get_checkpoint(wildcards, "untar"), | ||
image="mefs/data/processed/Control/{mask_file}.png", | ||
output: | ||
folder=directory("torchvision/Control/{mask_file}"), | ||
run: | ||
split_masks(input.image, output.folder) | ||
|
||
|
||
# checkpoint collate_split_masks: | ||
# input: | ||
# # checkpoint=lambda wildcards: checkpoints.get_checkpoint(wildcards, "split_masks"), | ||
# folder="torchvision/Control/{mask_file}", | ||
|
||
# output: | ||
# images="torchvision/Control/{mask_file}_{label}.png", | ||
# shell: | ||
# "cp {input.folder}/* {output.images}" | ||
|
||
checkpoint collate_split_masks: | ||
input: | ||
# checkpoint=lambda wildcards: checkpoints.get_checkpoint(wildcards, "split_masks"), | ||
image="torchvision/Control/{mask_file}/label_{label}.png", | ||
output: | ||
image="torchvision/Control/{mask_file}_{label}.png", | ||
shell: | ||
"cp {input.image} {output.image}" | ||
|