Commit cd12bea
[fix] script mostly working with torchdata, lots of comments to clear
ctr26 committed Aug 27, 2024
1 parent d103847 commit cd12bea
Showing 1 changed file with 58 additions and 38 deletions.

scripts/idr.py
@@ -12,12 +12,12 @@
 from hydra.utils import instantiate
 import os
 from pytorch_lightning import loggers as pl_loggers
+from bioimage_embed.lightning.dataloader import DataModule2

 # import fsspec
 import submitit
 import bioimage_embed

 # import submitit
+from pytorch_lightning.callbacks import ModelCheckpoint

 # %%
 memory = Memory(location=".", verbose=0)
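Note: joblib's Memory gives disk-backed caching of function calls, which keeps the FTP directory listing from being recomputed on every run. A minimal sketch of the pattern, assuming get_file_list is wrapped with memory.cache in the full script (the decorator and body here are illustrative, not taken from this diff):

from joblib import Memory

memory = Memory(location=".", verbose=0)

@memory.cache  # persists results under ./joblib/ and reuses them on re-runs
def get_file_list(glob_str, fs):
    # hypothetical body: expand the glob on the remote filesystem
    return fs.glob(glob_str)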
@@ -31,7 +31,7 @@
 # %%
 # # Setup fsspec filesystem for FTP access
 # fs = fsspec.filesystem("ftp", host=host, anon=True)
-fs = fsspec.filesystem("ftp", host=host, anon=True)
+# fs = fsspec.filesystem("ftp", host=host, anon=True)

 # %% [markdown]
 # # Glob pattern to match the files you're interested in
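Note: the fsspec call being commented out here builds an anonymous FTP filesystem object; the datapipe later opens files the same way via open_files_by_fsspec. A self-contained sketch (host and glob are illustrative placeholders, not values from this diff):

import fsspec

host = "ftp.example.org"  # placeholder; the script defines its own host
fs = fsspec.filesystem("ftp", host=host, anon=True)   # anonymous FTP access
files = fs.glob("images/**/*.tif")  # placeholder pattern; returns matching paths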
@@ -60,7 +60,7 @@ def get_file_list(glob_str, fs):


 # %%
-files = get_file_list(glob_str, fs)
+# files = get_file_list(glob_str, fs)


 # %%
@@ -69,10 +69,10 @@ def read_file(x):
         # Attempt to open the image
         print(x[0])
         stream = x[1].read()
-        print("Valid file")
+        # print("Valid file")
         return stream
     except Exception:
-        print("Invalid file")
+        # print("Invalid file")
         return None

@@ -87,33 +87,19 @@ def is_valid_image(x):
         # Attempt to open the image
         image = read_image(x)
         image.verify()  # Ensure it's a valid image
-        print("Valid image")
+        # print("Valid image")
         return True
     except (IOError, UnidentifiedImageError):
-        print("Invalid image")
+        # print("Invalid image")
        return False


 # %%
-dp = (
-    # IterableWrapper(files)
-    IterableWrapper(files)
-    .open_files_by_fsspec(
-        anon=True,
-        protocol="ftp",
-        host=host,
-        mode="rb",
-        filecache={"cache_storage": "tmp/idr"},
-    )
-    # .filter(filter_fn=is_valid_file)
-    .map(read_file)
-    .filter(filter_fn=is_valid_image)
-    .map(lambda x: Image.open(io.BytesIO(x)))
-)
+def image_open(x):
+    return Image.open(io.BytesIO(x)).convert("RGB")

-# %%
-a = next(iter(dp))
-print(a)

+def add_label(x):
+    return x, 0


 def train(num_gpus_per_node=1, num_nodes=1):
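Note: the two new helpers replace the inline lambdas from the deleted module-level datapipe: image_open decodes raw bytes into a 3-channel PIL image, and add_label pairs each image with a dummy class index so downstream code sees (image, label) tuples. A quick standalone check, using an in-memory PNG as a stand-in for a stream returned by read_file:

import io
from PIL import Image

buf = io.BytesIO()
Image.new("RGB", (4, 4)).save(buf, format="PNG")  # tiny test image
img = image_open(buf.getvalue())  # bytes -> PIL.Image, forced to RGB
sample = add_label(img)           # -> (image, 0)
print(sample)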
@@ -132,30 +118,48 @@ def train(num_gpus_per_node=1, num_nodes=1):

     transform = instantiate(config.Transform())

-    dataset = (
+    datapipe = (
         # IterableWrapper(files)
         IterableWrapper(files)
         .open_files_by_fsspec(
             anon=True,
             protocol="ftp",
             host=host,
             mode="rb",
-            filecache={"cache_storage": "tmp/idr"},
+            filecache={"cache_storage": "/tmp/idr"},
         )
         # .filter(filter_fn=is_valid_file)
         .map(read_file)
         .filter(filter_fn=is_valid_image)
-        .map(lambda x: Image.open(io.BytesIO(x)))
-        .map(lambda x: x.convert("RGB"))
+        .map(image_open)
         .map(transform)
+        .map(add_label)
+        .set_length(len(files))
+        # .sharding_filter(num_shards=num_nodes, shard_id=0)
+        # .batch(1)
+        # TODO add zip_with_iter() to combine the image and the label
+        # .zip_with_iter()
     )

     # dataset = datasets.ImageFolder(transform=transform)
+    # # %%
+    # dp = (
+    #     # IterableWrapper(files)
+    #     IterableWrapper(files)
+    #     .open_files_by_fsspec(
+    #         anon=True,
+    #         protocol="ftp",
+    #         host=host,
+    #         mode="rb",
+    #         filecache={"cache_storage": "tmp/idr"},
+    #     )
+    #     # .filter(filter_fn=is_valid_file)
+    #     .map(read_file)
+    #     .filter(filter_fn=is_valid_image)
+    #     .map(lambda x: Image.open(io.BytesIO(x)))
+    # )

-    a = next(iter(dataset))
+    # %%
+    a = next(iter(datapipe))
     print(a)

     # dp = Mapper(dp, lambda x: x.read())
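Note: an IterDataPipe like the one assembled above is itself an IterableDataset, so it can be consumed by the stock PyTorch DataLoader (torchdata's DataLoader2 is the other route). A minimal sketch with a toy pipe, since the real pipeline needs FTP access:

from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

pipe = IterableWrapper(range(8)).map(lambda x: (x, 0))  # (sample, label) pairs
loader = DataLoader(pipe, batch_size=4)
for batch in loader:
    print(batch)  # batched tensors, as the Lightning trainer would see them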
@@ -317,21 +321,37 @@ def train(num_gpus_per_node=1, num_nodes=1):
     # transform = instantiate(config.ATransform())
     # dataset = GlobDataset(root_dir+'**/*.tif*',transform,fs=fs)
     # dataset = RandomDataset(32, 64)
-    dataloader = config.DataLoader(dataset=dataset, num_workers=os.cpu_count())
+    # dataloader = config.DataLoader(
+    #     dataset=dataset, num_workers=os.cpu_count(), batch_size=None
+    # )
+    dataloader = DataModule2(datapipe, num_workers=os.cpu_count())

     # dataloader = config.DataLoader(num_workers=os.cpu_count())

-    assert instantiate(dataloader, batch_size=1)
+    # assert instantiate(dataloader, batch_size=1)
     # assert dataset[0]

     model = config.Model(input_dim=input_dim)

     lit_model = config.LightningModel(model=model)

+    checkpoint = ModelCheckpoint(
+        monitor="val/loss",
+        filename="best",
+        save_top_k=1,
+        mode="min",
+        save_last=True,
+    )

     wandb = pl_loggers.WandbLogger(project="idr", name="0093", log_model="all")
     # wandb.watch(lit_model, log="all")
     trainer = config.Trainer(
         accelerator="auto",
         devices=num_gpus_per_node,
         num_nodes=num_nodes,
         strategy="ddp",
-        callbacks=[],
+        callbacks=[checkpoint],
+        # callbacks=[],
         # plugin=[],
         logger=[wandb],
     )
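Note: the new ModelCheckpoint with monitor="val/loss" only fires if the LightningModule logs a metric under exactly that key during validation. A minimal sketch of the assumed logging call (the real model lives in bioimage_embed and is not shown in this diff):

import torch
import pytorch_lightning as pl

class LitSketch(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)    # placeholder loss
        self.log("val/loss", loss)  # key must match ModelCheckpoint(monitor=...)
        return loss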
@@ -345,10 +365,10 @@ def train(num_gpus_per_node=1, num_nodes=1):
     # breakpoint()

     bie = bioimage_embed.BioImageEmbed(cfg)
-    # wandb.watch(bie.icfg.lit_model, log="all")
+    wandb.watch(bie.icfg.lit_model, log="all")
     # wandb.run.define_metric("mse/val", summary="best")
     # wandb.run.define_metric("loss/val.loss", summary="best")

     # bie.check()
     bie.train()
     wandb.finish()