From 019164e0957f0e9fda34f808712756a31ba82c68 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Mon, 12 Aug 2024 17:27:53 +0000 Subject: [PATCH] feat: general improvements --- datasets/dataset preparations copy.ipynb | 1351 ++++++++++++++++++++++ datasets/dataset preparations.ipynb | 142 ++- flaxdiff/data/online_loader.py | 17 +- setup.py | 2 +- 4 files changed, 1473 insertions(+), 39 deletions(-) create mode 100644 datasets/dataset preparations copy.ipynb diff --git a/datasets/dataset preparations copy.ipynb b/datasets/dataset preparations copy.ipynb new file mode 100644 index 0000000..4566e90 --- /dev/null +++ b/datasets/dataset preparations copy.ipynb @@ -0,0 +1,1351 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import webdataset as wds\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import augmax\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import grain.python as pygrain\n", + "from typing import Any, Dict, List, Tuple\n", + "import numpy as np\n", + "from functools import partial\n", + "import tqdm \n", + "\n", + "import fsspec\n", + "import json\n", + "\n", + "import os\n", + "from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel\n", + "\n", + "from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk\n", + "from datasets.utils.file_utils import get_datasets_user_agent\n", + "from concurrent.futures import ThreadPoolExecutor\n", + "from functools import partial\n", + "import io\n", + "import urllib\n", + "\n", + "import PIL.Image\n", + "import cv2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "USER_AGENT = get_datasets_user_agent()\n", + "\n", + "\n", + "def fetch_single_image(image_url, timeout=None, retries=0):\n", + " for _ in range(retries + 1):\n", + " try:\n", + " request = urllib.request.Request(\n", + " image_url,\n", + " data=None,\n", + " headers={\"user-agent\": USER_AGENT},\n", + " )\n", + " with urllib.request.urlopen(request, timeout=timeout) as req:\n", + " image = PIL.Image.open(io.BytesIO(req.read()))\n", + " break\n", + " except Exception:\n", + " image = None\n", + " return image\n", + "\n", + "denormalizeImage = lambda x: (x + 1.0) * 127.5\n", + "\n", + "def plotImages(imgs, fig_size=(8, 8), dpi=100):\n", + " fig = plt.figure(figsize=fig_size, dpi=dpi)\n", + " imglen = imgs.shape[0]\n", + " for i in range(imglen):\n", + " plt.subplot(fig_size[0], fig_size[1], i + 1)\n", + " plt.imshow(jnp.astype(denormalizeImage(imgs[i, :, :, :]), jnp.uint8))\n", + " plt.axis(\"off\")\n", + " plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Filtering pipeline for various datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def dataMapper(map: Dict[str, Any]):\n", + " def _map(sample) -> Dict[str, Any]:\n", + " return {\n", + " \"url\": sample[map[\"url\"]],\n", + " \"caption\": sample[map[\"caption\"]],\n", + " }\n", + " return _map\n", + "\n", + "def imageFetcher():\n", + " def fetch_images(batch, num_threads, timeout=None, retries=0):\n", + " fetch_single_image_with_args = partial(fetch_single_image, timeout=timeout, retries=retries)\n", + " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", + " batch[\"image\"] = list(executor.map(fetch_single_image_with_args, batch[\"url\"]))\n", + " return batch\n", + " return fetch_images\n", + "\n", + "def 
mapDataset(dataset, args, mapper=dataMapper, workers=16, batch_size=10000, should_remove_columns=True, fn_kwargs={}):\n", + " if should_remove_columns:\n", + " remove_columns = dataset.column_names\n", + " else:\n", + " remove_columns = None\n", + " return dataset.map(mapper(*args), batched=True, batch_size=batch_size, remove_columns=remove_columns, num_proc=workers, fn_kwargs=fn_kwargs) " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e357e8fa8418439e8d2d0a8e23f3d1c5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map (num_proc=16): 0%| | 0/12096809 [00:00 value[\"max\"]:\n", + " return False\n", + " return True\n", + " return _filter\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5d3eddced904acca1ddd5625e84d5ed", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Filter (num_proc=64): 0%| | 0/746972269 [00:00 value[\"max\"]:\n", + " return False\n", + " return True\n", + " return _filter\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 65/100000000 [01:00<25938:14:13, 1.07it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[43], line 15\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124murl\u001b[39m\u001b[38;5;124m\"\u001b[39m: item[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124murl\u001b[39m\u001b[38;5;124m'\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcaption\u001b[39m\u001b[38;5;124m\"\u001b[39m: item[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m'\u001b[39m]} \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m batch]\n\u001b[1;32m 13\u001b[0m loader \u001b[38;5;241m=\u001b[39m DataLoader(filtered_leonardo_iterator, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m, num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m, persistent_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, collate_fn\u001b[38;5;241m=\u001b[39mcollate_fn)\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m tqdm\u001b[38;5;241m.\u001b[39mtqdm(loader, total\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100_000_000\u001b[39m):\n\u001b[1;32m 16\u001b[0m filtered_leonardo\u001b[38;5;241m.\u001b[39mextend(batch)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/tqdm/std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# Update and possibly print the 
progressbar.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 628\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1327\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[1;32m 1326\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1327\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1328\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1329\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m 1330\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1293\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1289\u001b[0m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. 
But we don't\u001b[39;00m\n\u001b[1;32m 1290\u001b[0m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1291\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1292\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1293\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 1295\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1131\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1118\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_try_get_data\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout\u001b[38;5;241m=\u001b[39m_utils\u001b[38;5;241m.\u001b[39mMP_STATUS_CHECK_INTERVAL):\n\u001b[1;32m 1119\u001b[0m \u001b[38;5;66;03m# Tries to fetch data from `self._data_queue` once for a given timeout.\u001b[39;00m\n\u001b[1;32m 1120\u001b[0m \u001b[38;5;66;03m# This can also be used as inner loop of fetching without timeout, with\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;66;03m# Returns a 2-tuple:\u001b[39;00m\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;66;03m# (bool: whether successfully get data, any: data if successful else None)\u001b[39;00m\n\u001b[1;32m 1130\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1131\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_queue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n\u001b[1;32m 1133\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1134\u001b[0m \u001b[38;5;66;03m# At timeout and error, we manually check whether any worker has\u001b[39;00m\n\u001b[1;32m 1135\u001b[0m \u001b[38;5;66;03m# failed. 
Note that this is the only mechanism for Windows to detect\u001b[39;00m\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;66;03m# worker failures.\u001b[39;00m\n", + "File \u001b[0;32m/usr/lib/python3.10/multiprocessing/queues.py:113\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m block:\n\u001b[1;32m 112\u001b[0m timeout \u001b[38;5;241m=\u001b[39m deadline \u001b[38;5;241m-\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic()\n\u001b[0;32m--> 113\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_poll():\n", + "File \u001b[0;32m/usr/lib/python3.10/multiprocessing/connection.py:257\u001b[0m, in \u001b[0;36m_ConnectionBase.poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_closed()\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_readable()\n\u001b[0;32m--> 257\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/usr/lib/python3.10/multiprocessing/connection.py:424\u001b[0m, in \u001b[0;36mConnection._poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_poll\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout):\n\u001b[0;32m--> 424\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 425\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mbool\u001b[39m(r)\n", + "File \u001b[0;32m/usr/lib/python3.10/multiprocessing/connection.py:931\u001b[0m, in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m 928\u001b[0m deadline \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic() \u001b[38;5;241m+\u001b[39m timeout\n\u001b[1;32m 930\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 931\u001b[0m ready \u001b[38;5;241m=\u001b[39m \u001b[43mselector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 932\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ready:\n\u001b[1;32m 933\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [key\u001b[38;5;241m.\u001b[39mfileobj \u001b[38;5;28;01mfor\u001b[39;00m (key, events) \u001b[38;5;129;01min\u001b[39;00m ready]\n", + "File \u001b[0;32m/usr/lib/python3.10/selectors.py:416\u001b[0m, in \u001b[0;36m_PollLikeSelector.select\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 414\u001b[0m ready \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 416\u001b[0m fd_event_list \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_selector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpoll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 417\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mInterruptedError\u001b[39;00m:\n\u001b[1;32m 418\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ready\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "filtered_leonardo_iterator = leonardo_100m.filter(leonardoFilter(heavyFilterMap))\n", + "filtered_leonardo = []\n", + "# for sample in tqdm.tqdm(filtered_leonardo_iterator, total=100_000_000):\n", + "# filtered_leonardo.append(sample)\n", + "from torch.utils.data import DataLoader\n", + "\n", + "def collate_fn(batch):\n", + " # urls = [item['url'] for item in batch]\n", + " # captions = [item['prompt'] for item in batch]\n", + " # return {\"url\": urls, \"caption\": captions}\n", + " return [{\"url\": item['url'], \"caption\": item['prompt']} for item in batch]\n", + "\n", + "loader = DataLoader(filtered_leonardo_iterator, batch_size=1000, num_workers=64, persistent_workers=True, collate_fn=collate_fn)\n", + "\n", + "for batch in tqdm.tqdm(loader, total=100_000_000//1000):\n", + " filtered_leonardo.extend(batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data Loading Experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import multiprocessing\n", + "import threading\n", + "from multiprocessing import Queue\n", + "# from arrayqueues.shared_arrays import ArrayQueue\n", + "# from faster_fifo import Queue\n", + "import time\n", + "import albumentations as A\n", + "import queue\n", + "\n", + "USER_AGENT = get_datasets_user_agent()\n", + "\n", + "data_queue = Queue(16*2000)\n", + "error_queue = Queue(16*2000)\n", + "\n", + "\n", + "def fetch_single_image(image_url, timeout=None, retries=0):\n", + " for _ in range(retries + 1):\n", + " try:\n", + " request = urllib.request.Request(\n", + " image_url,\n", + " data=None,\n", + " headers={\"user-agent\": USER_AGENT},\n", + " )\n", + " with urllib.request.urlopen(request, timeout=timeout) as req:\n", + " image = PIL.Image.open(io.BytesIO(req.read()))\n", + " break\n", + " except Exception:\n", + " image = None\n", + " return image\n", + "\n", + "def map_sample(\n", + " url, caption, \n", + " image_shape=(256, 256),\n", + " upscale_interpolation=cv2.INTER_LANCZOS4,\n", + " downscale_interpolation=cv2.INTER_AREA,\n", + "):\n", + " try:\n", + " image = fetch_single_image(url, timeout=15, retries=3) # Assuming fetch_single_image is defined elsewhere\n", + " if image is None:\n", + " return\n", + " \n", + " image = np.array(image)\n", + " original_height, original_width = image.shape[:2]\n", + " # check if the image is too small\n", + " if min(original_height, original_width) < min(image_shape):\n", + " return\n", + " # check if wrong aspect ratio\n", + " if max(original_height, original_width) / min(original_height, original_width) > 2:\n", + " return\n", + " # check if the variance is too low\n", + " if np.std(image) < 1e-4:\n", + " return\n", + " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", + " downscale = max(original_width, original_height) > max(image_shape)\n", + " interpolation = downscale_interpolation if downscale else upscale_interpolation\n", + " image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)\n", + " image = 
A.pad(\n", + " image,\n", + " min_height=image_shape[0],\n", + " min_width=image_shape[1],\n", + " border_mode=cv2.BORDER_CONSTANT,\n", + " value=[255, 255, 255],\n", + " )\n", + " data_queue.put({\n", + " \"url\": url,\n", + " \"caption\": caption,\n", + " \"image\": image\n", + " })\n", + " except Exception as e:\n", + " error_queue.put({\n", + " \"url\": url,\n", + " \"caption\": caption,\n", + " \"error\": str(e)\n", + " })\n", + " \n", + "def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):\n", + " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", + " executor.map(map_sample, batch[\"url\"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)\n", + " \n", + "def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):\n", + " map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)\n", + " shard_len = len(dataset) // num_workers\n", + " print(f\"Local Shard lengths: {shard_len}\")\n", + " with multiprocessing.Pool(num_workers) as pool:\n", + " iteration = 0\n", + " while True:\n", + " # Repeat forever\n", + " dataset = dataset.shuffle(seed=iteration)\n", + " shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]\n", + " pool.map(map_batch_fn, shards)\n", + " iteration += 1\n", + " \n", + "class ImageBatchIterator:\n", + " def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):\n", + " self.dataset = dataset\n", + " self.num_workers = num_workers\n", + " self.batch_size = batch_size\n", + " loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)\n", + " self.thread = threading.Thread(target=loader, args=(dataset))\n", + " self.thread.start()\n", + " \n", + " def __iter__(self):\n", + " return self\n", + " \n", + " def __next__(self):\n", + " def fetcher(_):\n", + " return data_queue.get()\n", + " with ThreadPoolExecutor(max_workers=self.batch_size) as executor:\n", + " batch = list(executor.map(fetcher, range(self.batch_size)))\n", + " return batch\n", + " \n", + " def __del__(self):\n", + " self.thread.join()\n", + " \n", + " def __len__(self):\n", + " return len(self.dataset) // self.batch_size\n", + " \n", + "def default_collate(batch):\n", + " urls = [sample[\"url\"] for sample in batch]\n", + " captions = [sample[\"caption\"] for sample in batch]\n", + " images = np.stack([sample[\"image\"] for sample in batch], axis=0)\n", + " return {\n", + " \"url\": urls,\n", + " \"caption\": captions,\n", + " \"image\": images,\n", + " }\n", + " \n", + "def dataMapper(map: Dict[str, Any]):\n", + " def _map(sample) -> Dict[str, Any]:\n", + " return {\n", + " \"url\": sample[map[\"url\"]],\n", + " \"caption\": sample[map[\"caption\"]],\n", + " }\n", + " return _map\n", + "\n", + "class OnlineStreamingDataLoader():\n", + " def __init__(\n", + " self, \n", + " dataset, \n", + " batch_size=64, \n", + " num_workers=16, \n", + " num_threads=512,\n", + " default_split=\"all\",\n", + " pre_map_maker=dataMapper, \n", + " pre_map_def={\n", + " \"url\": \"URL\",\n", + " \"caption\": \"TEXT\",\n", + " },\n", + " global_process_count=1,\n", + " global_process_index=0,\n", + " prefetch=1000,\n", + " collate_fn=default_collate,\n", + " ):\n", + " if isinstance(dataset, str):\n", + " dataset_path = dataset\n", + " print(\"Loading dataset from path\")\n", + " dataset = load_dataset(dataset_path, 
split=default_split)\n", + " elif isinstance(dataset, list):\n", + " if isinstance(dataset[0], str):\n", + " print(\"Loading multiple datasets from paths\")\n", + " dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]\n", + " else:\n", + " print(\"Concatenating multiple datasets\")\n", + " dataset = concatenate_datasets(dataset)\n", + " dataset = dataset.map(pre_map_maker(pre_map_def))\n", + " self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)\n", + " print(f\"Dataset length: {len(dataset)}\")\n", + " self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)\n", + " self.collate_fn = collate_fn\n", + " \n", + " # Launch a thread to load batches in the background\n", + " self.batch_queue = queue.Queue(prefetch)\n", + " \n", + " def batch_loader():\n", + " for batch in self.iterator:\n", + " self.batch_queue.put(batch)\n", + " \n", + " self.loader_thread = threading.Thread(target=batch_loader)\n", + " self.loader_thread.start()\n", + " \n", + " def __iter__(self):\n", + " return self\n", + " \n", + " def __next__(self):\n", + " return self.collate_fn(self.batch_queue.get())\n", + " # return self.collate_fn(next(self.iterator))\n", + " \n", + " def __len__(self):\n", + " return len(self.dataset) // self.batch_size\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flaxdiff.data.online_loader import OnlineStreamingDataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataloader = OnlineStreamingDataLoader(\"ChristophSchuhmann/MS_COCO_2017_URL_TEXT\", batch_size=16, num_workers=16, default_split=\"train\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataloader.batch_queue.qsize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_queue.qsize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "error_queue.qsize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in tqdm.tqdm(range(0, 2000)):\n", + " batch = next(dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def parallel_loading(dataset):\n", + " dataset.map(map_batch_fn, num_proc=64, batched=True, batch_size=64, fn_kwargs={\"num_threads\": 64})\n", + " \n", + "thread = threading.Thread(target=parallel_loading, args=(mscoco_fused,))\n", + "thread.start()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from concurrent.futures import ThreadPoolExecutor\n", + "import aiohttp\n", + "from io import BytesIO\n", + "import asyncio\n", + "from PIL import Image\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class URLDataset(Dataset):\n", + " def __init__(self, data):\n", + " self.data = data\n", + " \n", + " async def fetch_image(self, url):\n", + " async with aiohttp.ClientSession() as session:\n", 
+ " async with session.get(url) as response:\n", + " image_data = await response.read()\n", + " image = Image.open(BytesIO(image_data))\n", + " return image\n", + " \n", + " def __getitem__(self, index):\n", + " data = self.data[index]\n", + " url, caption = data['url'], data['caption']\n", + " loop = asyncio.get_event_loop()\n", + " image = loop.run_until_complete(self.fetch_image(url))\n", + " # Preprocess image and return along with the caption\n", + " image = image.resize((256, 256)) # Example resize\n", + " return image, caption\n", + " \n", + " def __len__(self):\n", + " return len(self.data)\n", + "\n", + "# Example usage\n", + "dataset = URLDataset(mscoco_fused)\n", + "data_loader = DataLoader(dataset, batch_size=256, num_workers=8, prefetch_factor=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in tqdm.tqdm(data_loader):\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomDataset(Dataset):\n", + " def __init__(self, dataset):\n", + " self.dataset = dataset\n", + " \n", + " def __len__(self):\n", + " return len(self.dataset)\n", + " \n", + " def __getitem__(self, idx):\n", + " url = self.dataset[idx]['url']\n", + " caption = self.dataset[idx]['caption']\n", + " image = fetch_single_image(url) # Assuming fetch_single_image is defined elsewhere\n", + " return {\n", + " \"url\": url,\n", + " \"caption\": caption,\n", + " \"image\": image\n", + " }\n", + "\n", + "def collate_fn(batch):\n", + " # Custom collation logic if needed\n", + " print(batch)\n", + " # urls = [item[\"url\"] for item in batch]\n", + " # fetch_single_image_with_args = partial(fetch_single_image, timeout=10, retries=3)\n", + " # with ThreadPoolExecutor(max_workers=len(batch)) as executor:\n", + " # images = list(executor.map(fetch_single_image_with_args, urls))\n", + " \n", + " # return {\n", + " # \"url\": urls,\n", + " # \"caption\": [item[\"caption\"] for item in batch],\n", + " # \"image\": images\n", + " # }\n", + " \n", + "# Assuming mscoco_fused is your dataset\n", + "dataset = CustomDataset(mscoco_fused)\n", + "data_loader = DataLoader(dataset, batch_size=512, num_workers=8, collate_fn=collate_fn, prefetch_factor=100)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in tqdm.tqdm(data_loader):\n", + " # print(i)\n", + " # break\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "queue.qsize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install arrayqueues" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with multiprocessing.Manager() as manager:\n", + "img_queue = manager.Queue()\n", + "process = multiprocessing.Process(target=parallel_image_loader, args=(mscoco_fused, img_queue, 8))\n", + "process.start()\n", + "process.join()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import multiprocessing\n", + "from multiprocessing import shared_memory\n", + "import numpy as np\n", + "from concurrent.futures import ThreadPoolExecutor\n", + "from datasets import Dataset\n", + "import threading\n", + "\n", + "def create_shared_array(shape, dtype):\n", + " \"\"\"Create a shared numpy array.\"\"\"\n", 
+ " nbytes = np.prod(shape) * np.dtype(dtype).itemsize\n", + " shm = shared_memory.SharedMemory(create=True, size=nbytes)\n", + " array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)\n", + " return shm, array\n", + "\n", + "def map_fn(url, caption, shared_array, shared_index, lock, shape, dtype):\n", + " image = fetch_single_image(url) # Assuming fetch_single_image is defined elsewhere\n", + " with lock:\n", + " index = shared_index.value\n", + " shared_array[index] = np.frombuffer(image, dtype=dtype).reshape(shape) # Store image in shared memory\n", + " shared_index.value += 1 # Move to the next index\n", + " # Save additional info (url, caption) if necessary\n", + "\n", + "def map_batch_fn(batch, shared_array, shared_index, lock, shape, dtype, num_threads=64):\n", + " with ThreadPoolExecutor(max_workers=num_threads) as executor:\n", + " executor.map(\n", + " map_fn, \n", + " batch[\"url\"], \n", + " batch['caption'], \n", + " [shared_array] * len(batch[\"url\"]), \n", + " [shared_index] * len(batch[\"url\"]), \n", + " [lock] * len(batch[\"url\"]), \n", + " [shape] * len(batch[\"url\"]), \n", + " [dtype] * len(batch[\"url\"])\n", + " )\n", + "\n", + "def parallel_image_loader(dataset: Dataset, shared_array, shared_index, lock, shape, dtype, num_workers: int = 8):\n", + " batch_len = len(dataset) // num_workers\n", + " batches = [dataset[i * batch_len:(i + 1) * batch_len] for i in range(num_workers)]\n", + " with multiprocessing.Pool(num_workers) as pool:\n", + " pool.starmap(\n", + " map_batch_fn, \n", + " [(batch, shared_array, shared_index, lock, shape, dtype) for batch in batches]\n", + " )\n", + "\n", + "class ImageBatchIterator:\n", + " def __init__(self, dataset: Dataset, num_workers: int = 8, batch_size: int = 64, image_shape=(224, 224, 3), dtype=np.uint8):\n", + " self.dataset = dataset\n", + " self.num_workers = num_workers\n", + " self.batch_size = batch_size\n", + " self.image_shape = image_shape\n", + " self.dtype = dtype\n", + " \n", + " # Create shared memory array\n", + " self.shm, self.shared_array = create_shared_array((len(dataset),) + image_shape, dtype)\n", + " self.shared_index = multiprocessing.Value('i', 0) # Shared index counter\n", + " self.lock = multiprocessing.Lock() # Lock for safe indexing\n", + " \n", + " self.thread = threading.Thread(target=parallel_image_loader, args=(\n", + " dataset, self.shared_array, self.shared_index, self.lock, image_shape, dtype, num_workers))\n", + " self.thread.start()\n", + " \n", + " def __iter__(self):\n", + " return self\n", + " \n", + " def __next__(self):\n", + " if self.shared_index.value < self.batch_size:\n", + " raise StopIteration\n", + " \n", + " batch_start = max(0, self.shared_index.value - self.batch_size)\n", + " batch_end = self.shared_index.value\n", + " batch = self.shared_array[batch_start:batch_end]\n", + " return batch\n", + " \n", + " def __del__(self):\n", + " self.thread.join()\n", + " self.shm.close()\n", + " self.shm.unlink() # Free shared memory when done\n", + " \n", + " def __len__(self):\n", + " return len(self.dataset) // self.batch_size\n", + "\n", + "# Example usage:\n", + "dataset = ImageBatchIterator(mscoco_fused, num_workers=16, batch_size=64, image_shape=(224, 224, 3))\n", + "for i in tqdm.tqdm(range(0, 100)):\n", + " batch = next(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in tqdm.tqdm(range(0, 100)):\n", + " batch = next(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/datasets/dataset preparations.ipynb b/datasets/dataset preparations.ipynb index 8f82807..d3c7fa0 100644 --- a/datasets/dataset preparations.ipynb +++ b/datasets/dataset preparations.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -255,26 +255,34 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "test = load_from_disk(\"gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, "metadata": {}, "outputs": [ { - "ename": "FileNotFoundError", - "evalue": "Couldn't find a dataset script at /home/mrwhite0racle/research/datasets/gs:/flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017/laion-aesthetics-12m+mscoco-2017.py or any data file in the same directory.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m test \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/datasets/load.py:2594\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2589\u001b[0m verification_mode \u001b[38;5;241m=\u001b[39m VerificationMode(\n\u001b[1;32m 2590\u001b[0m (verification_mode \u001b[38;5;129;01mor\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mBASIC_CHECKS) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m save_infos \u001b[38;5;28;01melse\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS\n\u001b[1;32m 2591\u001b[0m )\n\u001b[1;32m 2593\u001b[0m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 2594\u001b[0m builder_instance \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2595\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2596\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2597\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2598\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2599\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2600\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2601\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2602\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2603\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2604\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2605\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2606\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2607\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2608\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2609\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2611\u001b[0m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 2612\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/datasets/load.py:2266\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, use_auth_token, storage_options, trust_remote_code, _require_default_config_name, **config_kwargs)\u001b[0m\n\u001b[1;32m 2264\u001b[0m download_config \u001b[38;5;241m=\u001b[39m download_config\u001b[38;5;241m.\u001b[39mcopy() \u001b[38;5;28;01mif\u001b[39;00m download_config \u001b[38;5;28;01melse\u001b[39;00m DownloadConfig()\n\u001b[1;32m 2265\u001b[0m download_config\u001b[38;5;241m.\u001b[39mstorage_options\u001b[38;5;241m.\u001b[39mupdate(storage_options)\n\u001b[0;32m-> 2266\u001b[0m dataset_module \u001b[38;5;241m=\u001b[39m \u001b[43mdataset_module_factory\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2267\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2268\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2269\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2270\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2271\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2272\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2273\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2275\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_require_default_config_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2276\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_custom_configs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2277\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2278\u001b[0m \u001b[38;5;66;03m# Get dataset builder class from the processing script\u001b[39;00m\n\u001b[1;32m 2279\u001b[0m builder_kwargs \u001b[38;5;241m=\u001b[39m dataset_module\u001b[38;5;241m.\u001b[39mbuilder_kwargs\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/datasets/load.py:1916\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, cache_dir, trust_remote_code, _require_default_config_name, _require_custom_configs, **download_kwargs)\u001b[0m\n\u001b[1;32m 1914\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e1 \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1915\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1916\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 1917\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find a dataset script at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrelative_to_absolute_path(combined_path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m or any data file in the same directory.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1918\u001b[0m )\n", - "\u001b[0;31mFileNotFoundError\u001b[0m: Couldn't find a dataset script at /home/mrwhite0racle/research/datasets/gs:/flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017/laion-aesthetics-12m+mscoco-2017.py or any data file in the same directory." 
- ] + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['url', 'caption'],\n", + " num_rows: 15055574\n", + "})" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "test = load_dataset(\"gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017\")" + "test.shuffle()" ] }, { @@ -295,16 +303,45 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0a01395a9cce4b2aa6b692d7299fa6f1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/128 [00:00