Skip to content

Commit

Permalink
feat: exposed enterpolation options
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 13, 2024
1 parent 019164e commit 565475c
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 47 deletions.
122 changes: 85 additions & 37 deletions datasets/dataset preparations copy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,27 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a24efac453bc492da450fd084c5cb457",
"model_id": "9c55ed16514e4886a4b30c1c7a6a87ca",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading readme: 0%| | 0.00/1.64k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "caf568ec9bde4a4db1c2a21afc0f9a03",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -702,24 +716,33 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "78d28431417b422cac6233c9252b9a24",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0/958 [00:00<?, ?files/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"leonardo = load_dataset(\"bigdata-pw/leonardo\", split='train', streaming=True)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"leonardo_100m = leonardo.shuffle().take(100_000_000)"
"leonardo = load_dataset(\"bigdata-pw/leonardo\", split=\"all\", num_proc=64)\n",
"leonardoMap = {\n",
" \"url\": \"image_url\",\n",
" \"caption\": \"caption\",\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -741,39 +764,44 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"leonardoLiked = leonardo.filter(leonardoFilter(heavyFilterMap), num_proc=120)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final_data = mapDataset(leonardoLiked, ({\n",
" \"url\":\"url\",\n",
" \"caption\":\"text\"\n",
" },), batch_size=1000000, workers=None)\n",
"\n",
"final_data.save_to_disk(\"gs://flaxdiff-datasets-regional/datasets/leonardo-liked\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 65/100000000 [01:00<25938:14:13, 1.07it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[43], line 15\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124murl\u001b[39m\u001b[38;5;124m\"\u001b[39m: item[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124murl\u001b[39m\u001b[38;5;124m'\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcaption\u001b[39m\u001b[38;5;124m\"\u001b[39m: item[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprompt\u001b[39m\u001b[38;5;124m'\u001b[39m]} \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m batch]\n\u001b[1;32m 13\u001b[0m loader \u001b[38;5;241m=\u001b[39m DataLoader(filtered_leonardo_iterator, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m, num_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m, persistent_workers\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, collate_fn\u001b[38;5;241m=\u001b[39mcollate_fn)\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m tqdm\u001b[38;5;241m.\u001b[39mtqdm(loader, total\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100_000_000\u001b[39m):\n\u001b[1;32m 16\u001b[0m filtered_leonardo\u001b[38;5;241m.\u001b[39mextend(batch)\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/tqdm/std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1178\u001b[0m time \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_time\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m obj \u001b[38;5;129;01min\u001b[39;00m iterable:\n\u001b[1;32m 1182\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m obj\n\u001b[1;32m 1183\u001b[0m \u001b[38;5;66;03m# Update and possibly print the progressbar.\u001b[39;00m\n\u001b[1;32m 1184\u001b[0m \u001b[38;5;66;03m# Note: does not call self.update(1) for speed optimisation.\u001b[39;00m\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 628\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 630\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1327\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data)\n\u001b[1;32m 1326\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1327\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1328\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1329\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m 1330\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1293\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1289\u001b[0m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m 1290\u001b[0m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1291\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1292\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1293\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1294\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 1295\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1131\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1118\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_try_get_data\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout\u001b[38;5;241m=\u001b[39m_utils\u001b[38;5;241m.\u001b[39mMP_STATUS_CHECK_INTERVAL):\n\u001b[1;32m 1119\u001b[0m \u001b[38;5;66;03m# Tries to fetch data from `self._data_queue` once for a given timeout.\u001b[39;00m\n\u001b[1;32m 1120\u001b[0m \u001b[38;5;66;03m# This can also be used as inner loop of fetching without timeout, with\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;66;03m# Returns a 2-tuple:\u001b[39;00m\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;66;03m# (bool: whether successfully get data, any: data if successful else None)\u001b[39;00m\n\u001b[1;32m 1130\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1131\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_queue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n\u001b[1;32m 1133\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1134\u001b[0m \u001b[38;5;66;03m# At timeout and error, we manually check whether any worker has\u001b[39;00m\n\u001b[1;32m 1135\u001b[0m \u001b[38;5;66;03m# failed. Note that this is the only mechanism for Windows to detect\u001b[39;00m\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;66;03m# worker failures.\u001b[39;00m\n",
"File \u001b[0;32m/usr/lib/python3.10/multiprocessing/queues.py:113\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m block:\n\u001b[1;32m 112\u001b[0m timeout \u001b[38;5;241m=\u001b[39m deadline \u001b[38;5;241m-\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic()\n\u001b[0;32m--> 113\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m Empty\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_poll():\n",
"File \u001b[0;32m/usr/lib/python3.10/multiprocessing/connection.py:257\u001b[0m, in \u001b[0;36m_ConnectionBase.poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_closed()\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_readable()\n\u001b[0;32m--> 257\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_poll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/lib/python3.10/multiprocessing/connection.py:424\u001b[0m, in \u001b[0;36mConnection._poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_poll\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout):\n\u001b[0;32m--> 424\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 425\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mbool\u001b[39m(r)\n",
"File \u001b[0;32m/usr/lib/python3.10/multiprocessing/connection.py:931\u001b[0m, in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m 928\u001b[0m deadline \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic() \u001b[38;5;241m+\u001b[39m timeout\n\u001b[1;32m 930\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m--> 931\u001b[0m ready \u001b[38;5;241m=\u001b[39m \u001b[43mselector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mselect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 932\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ready:\n\u001b[1;32m 933\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [key\u001b[38;5;241m.\u001b[39mfileobj \u001b[38;5;28;01mfor\u001b[39;00m (key, events) \u001b[38;5;129;01min\u001b[39;00m ready]\n",
"File \u001b[0;32m/usr/lib/python3.10/selectors.py:416\u001b[0m, in \u001b[0;36m_PollLikeSelector.select\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 414\u001b[0m ready \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 416\u001b[0m fd_event_list \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_selector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpoll\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 417\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mInterruptedError\u001b[39;00m:\n\u001b[1;32m 418\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ready\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
"249it [06:54, 1.57it/s] "
]
}
],
"source": [
"leonardo = load_dataset(\"bigdata-pw/leonardo\", split='train', streaming=True)\n",
"leonardo_100m = leonardo.shuffle()#.take(100_000_000)\n",
"\n",
"filtered_leonardo_iterator = leonardo_100m.filter(leonardoFilter(heavyFilterMap))\n",
"filtered_leonardo = []\n",
"# for sample in tqdm.tqdm(filtered_leonardo_iterator, total=100_000_000):\n",
Expand All @@ -788,10 +816,30 @@
"\n",
"loader = DataLoader(filtered_leonardo_iterator, batch_size=1000, num_workers=64, persistent_workers=True, collate_fn=collate_fn)\n",
"\n",
"for batch in tqdm.tqdm(loader, total=100_000_000//1000):\n",
"for batch in tqdm.tqdm(loader, total=100_000//1000):\n",
" filtered_leonardo.extend(batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"191462"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(filtered_leonardo)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Loading

0 comments on commit 565475c

Please sign in to comment.