Skip to content

Commit

Permalink
Fixed some more bugs in dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 12, 2024
1 parent dbdff91 commit 11d36e1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
27 changes: 19 additions & 8 deletions flaxdiff/data/online_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
data_queue = Queue(16*2000)
error_queue = Queue(16*2000)


def fetch_single_image(image_url, timeout=None, retries=0):
for _ in range(retries + 1):
try:
Expand All @@ -46,11 +45,13 @@ def fetch_single_image(image_url, timeout=None, retries=0):
def map_sample(
url, caption,
image_shape=(256, 256),
timeout=15,
retries=3,
upscale_interpolation=cv2.INTER_LANCZOS4,
downscale_interpolation=cv2.INTER_AREA,
):
try:
image = fetch_single_image(url, timeout=15, retries=3) # Assuming fetch_single_image is defined elsewhere
image = fetch_single_image(url, timeout=timeout, retries=retries) # Assuming fetch_single_image is defined elsewhere
if image is None:
return

Expand Down Expand Up @@ -84,15 +85,24 @@ def map_sample(
"original_width": original_width,
})
except Exception as e:
print(f"Error in map_sample: {str(e)}")
error_queue.put({
"url": url,
"caption": caption,
"error": str(e)
})

def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
with ThreadPoolExecutor(max_workers=num_threads) as executor:
executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)

def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
    """Fetch and process all samples of *batch* concurrently.

    Each (url, caption) pair is dispatched to ``map_sample`` on a thread
    pool of ``num_threads`` workers. Per-sample failures are handled inside
    ``map_sample`` itself; a failure of the batch as a whole (e.g. a missing
    key in *batch*) is recorded on the module-level ``error_queue`` rather
    than propagated to the caller.
    """
    try:
        # Bind the per-sample options once so the pool only varies url/caption.
        worker = partial(map_sample, image_shape=image_shape,
                         timeout=timeout, retries=retries)
        urls = batch["url"]
        captions = batch['caption']
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            pool.map(worker, urls, captions)
    except Exception as e:
        print(f"Error in map_batch: {str(e)}")
        error_queue.put({
            "batch": batch,
            "error": str(e)
        })

def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
Expand All @@ -102,8 +112,10 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
iteration = 0
while True:
# Repeat forever
dataset = dataset.shuffle(seed=iteration)
print(f"Shuffling dataset with seed {iteration}")
# dataset = dataset.shuffle(seed=iteration)
shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
print(f"mapping {len(shards)} shards")
pool.map(map_batch_fn, shards)
iteration += 1

Expand Down Expand Up @@ -205,4 +217,3 @@ def __next__(self):

def __len__(self):
    # Length of the loader is simply the length of the wrapped dataset
    # (number of samples, not number of batches — verify against callers).
    return len(self.dataset)

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
version='0.1.15',
version='0.1.16',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 11d36e1

Please sign in to comment.