Skip to content

Commit

Permalink
vampire snakefile
Browse files Browse the repository at this point in the history
  • Loading branch information
ctr26 committed Sep 29, 2023
1 parent 6086894 commit 1c312e0
Showing 1 changed file with 141 additions and 0 deletions.
141 changes: 141 additions & 0 deletions data/vampire/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Snakefile

from snakemake.remote import AUTO
# from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider
from snakemake.remote.zenodo import RemoteProvider
from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider

HTTP = HTTPRemoteProvider()
# HTTP = HTTPRemoteProvider()
import numpy as np
from PIL import Image
import os

# zenodo = RemoteProvider(deposition=7388245)
# HTTP = HTTPRemoteProvider()
# HTTP.remote("mefs.tar.gz", keep_local=True)
# breakpoint()

#
def split_masks_pil(input_file, output_dir):
input_path = Path(input_file)
output_dir = Path(output_dir)
with Image.open(input_path) as img:
# Get unique class labels present in the image
class_labels = set(img.getdata()) - set([0])
output_dir.mkdir(parents=True, exist_ok=True)

# Create a separate binary image for each class label
for label in class_labels:
# Convert the mask image to a binary image (black and white) with only the current label
binary_img = Image.new("1", img.size)
binary_img.putdata([1 if pixel == label else 0 for pixel in img.getdata()])

# Save the binary image in the output directory
output_path = os.path.join(output_dir, f"label_{label}.png")
binary_img.save(output_path)


def split_masks_np(input_file, output_dir):
mask = np.array(Image.open(input_file))
unique_labels = np.unique(mask)
class_labels = set(unique_labels) - set([0])
# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Loop through each unique class label and create a separate binary image
for label in class_labels:
binary_mask = (mask == label).astype(np.uint8)
binary_img = Image.fromarray(
binary_mask * 255
) # Convert to 8-bit binary image (0 or 255)
output_file = os.path.join(output_dir, f"label_{label}.png")
binary_img.save(output_file)


split_masks = split_masks_pil


def get_mask_files(wildcards):
# Get the list of mask files
ckpt = checkpoints.untar.get(**wildcards)
mask_file, = glob_wildcards("mefs/data/processed/Control/{mask_file}.png")
return expand("torchvision/Control/{mask_file}", mask_file=mask_file)
# breakpoint()
# return mask_files


def get_mask_label_files(wildcards):
# breakpoint()
mask_files = get_mask_files(wildcards)
# breakpoint
# ckpt = checkpoints.split_masks.get(**wildcards)
mask_file, label = glob_wildcards("torchvision/Control/{mask_file}/label_{label}.png")
# mask_file = [s.replace('/', '_') for s in string_list]
# breakpoint()
return expand(
"torchvision/Control/{mask_file}_{label}.png",
zip,
mask_file=mask_file,
label=label,
)


rule all:
input:
"mefs.tar.gz",
"mefs",
get_mask_files,
get_mask_label_files,


# rule download_and_untar:
# input:
# file= HTTP.remote(
# "https://zenodo.org/record/7388245/files/mefs.tar.gz", keep_local=False
# )
# output:
# "mefs.tar.gz",
# shell:
# "cp {input.file} {output}"


checkpoint untar:
input:
"mefs.tar.gz",
output:
directory("mefs"),
shell:
"tar -xzf {input} -p"


# New rule to split mask images into individual PNG files
checkpoint split_masks:
input:
# checkpoint=lambda wildcards: checkpoints.get_checkpoint(wildcards, "untar"),
image="mefs/data/processed/Control/{mask_file}.png",
output:
folder=directory("torchvision/Control/{mask_file}"),
run:
split_masks(input.image, output.folder)


# checkpoint collate_split_masks:
# input:
# # checkpoint=lambda wildcards: checkpoints.get_checkpoint(wildcards, "split_masks"),
# folder="torchvision/Control/{mask_file}",

# output:
# images="torchvision/Control/{mask_file}_{label}.png",
# shell:
# "cp {input.folder}/* {output.images}"

checkpoint collate_split_masks:
input:
# checkpoint=lambda wildcards: checkpoints.get_checkpoint(wildcards, "split_masks"),
image="torchvision/Control/{mask_file}/label_{label}.png",
output:
image="torchvision/Control/{mask_file}_{label}.png",
shell:
"cp {input.image} {output.image}"

0 comments on commit 1c312e0

Please sign in to comment.