diff --git a/examples/image/map_images_from_urls.py b/examples/image/map_images_from_urls.py
new file mode 100644
index 00000000..e551de3d
--- /dev/null
+++ b/examples/image/map_images_from_urls.py
@@ -0,0 +1,16 @@
+from datasets import load_dataset
+from nomic import AtlasDataset
+from tqdm import tqdm
+
+dataset = load_dataset('ChihHsuan-Yang/Arboretum', split='train[:100000]')
+ids = list(range(len(dataset)))
+dataset = dataset.add_column("id", ids)
+
+atlas_dataset = AtlasDataset("andriy/arboretum-100k-image-url-upload", unique_id_field="id")
+records = dataset.remove_columns(["photo_id"]).to_list()
+
+records = [record for record in tqdm(records) if record["photo_url"] is not None]
+image_urls = [record.pop("photo_url") for record in records]
+
+atlas_dataset.add_data(data=records, blobs=image_urls)
+atlas_dataset.create_index(embedding_model="nomic-embed-vision-v1.5", topic_model=False)
\ No newline at end of file
diff --git a/nomic/dataset.py b/nomic/dataset.py
index 4906f332..50ffa395 100644
--- a/nomic/dataset.py
+++ b/nomic/dataset.py
@@ -1095,9 +1095,9 @@ def create_index(
         modality = self.meta["modality"]
 
         if modality == "image":
-            indexed_field = "_blob_hash"
             if indexed_field is not None:
                 logger.warning("Ignoring indexed_field for image datasets. Only _blob_hash is supported.")
+            indexed_field = "_blob_hash"
 
         colorable_fields = []
 
@@ -1170,11 +1170,14 @@ def create_index(
         if modality == "image":
             if topic_model.topic_label_field is None:
-                print(
-                    "You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
-                )
+                if topic_model.build_topic_model:
+                    logger.warning(
+                        "You did not specify the `topic_label_field` option in your topic_model, your dataset will not contain auto-labeled topics."
+                    )
+                    topic_model.build_topic_model = False
+
                 topic_field = None
-                topic_model.build_topic_model = False
+
             else:
                 topic_field = (
                     topic_model.topic_label_field if topic_model.topic_label_field != indexed_field else None
                 )
@@ -1361,7 +1364,7 @@ def add_data(
         Args:
             data: A pandas DataFrame, list of dictionaries, or pyarrow Table matching the dataset schema.
             embeddings: A numpy array of embeddings: each row corresponds to a row in the table. Use if you already have embeddings for your datapoints.
-            blobs: A list of image paths, bytes, or PIL Images. Use if you want to create an AtlasDataset using image embeddings over your images. Note: Blobs are stored locally only.
+            blobs: A list of image paths, bytes, PIL Images, or URLs. Use if you want to create an AtlasDataset using image embeddings over your images.
             pbar: (Optional). A tqdm progress bar to update.
         """
         if embeddings is not None:
@@ -1408,6 +1411,7 @@ def _add_blobs(
 
         # TODO: add support for other modalities
         images = []
+        urls = []
         for uuid, blob in tqdm(zip(ids, blobs), total=len(ids), desc="Loading images"):
             if isinstance(blob, str) and os.path.exists(blob):
                 # Auto resize to max 512x512
@@ -1417,6 +1421,8 @@ def _add_blobs(
                 buffered = BytesIO()
                 image.save(buffered, format="JPEG")
                 images.append((uuid, buffered.getvalue()))
+            elif isinstance(blob, str) and (blob.startswith("http://") or blob.startswith("https://")):
+                urls.append((uuid, blob))
             elif isinstance(blob, bytes):
                 images.append((uuid, blob))
             elif isinstance(blob, Image.Image):
@@ -1428,22 +1434,40 @@ def _add_blobs(
             else:
                 raise ValueError(f"Invalid blob type for {uuid}. Must be a path to an image, bytes, or PIL Image.")
 
-        batch_size = 40
-        num_workers = 10
+        if len(images) == 0 and len(urls) == 0:
+            raise ValueError("No valid images found in the blobs list.")
+        if len(images) > 0 and len(urls) > 0:
+            raise ValueError("Cannot mix local and remote blobs in the same batch.")
+
+        if urls:
+            batch_size = 10
+            num_workers = 10
+        else:
+            batch_size = 40
+            num_workers = 10
 
         def send_request(i):
             image_batch = images[i : i + batch_size]
-            ids = [uuid for uuid, _ in image_batch]
-            blobs = [("blobs", blob) for _, blob in image_batch]
+            urls_batch = urls[i : i + batch_size]
+
+            if image_batch:
+                blobs = [("blobs", blob) for _, blob in image_batch]
+                ids = [uuid for uuid, _ in image_batch]
+            else:
+                blobs = []
+                ids = [uuid for uuid, _ in urls_batch]
+                urls_batch = [url for _, url in urls_batch]
+
             response = requests.post(
                 self.atlas_api_path + blob_upload_endpoint,
                 headers=self.header,
-                data={"dataset_id": self.id},
+                data={"dataset_id": self.id, "urls": urls_batch},
                 files=blobs,
             )
             if response.status_code != 200:
                 raise Exception(response.text)
-            return {uuid: blob_hash for uuid, blob_hash in zip(ids, response.json()["hashes"])}
+            id2hash = {uuid: blob_hash for uuid, blob_hash in zip(ids, response.json()["hashes"])}
+            return id2hash
 
         # if this method is being called internally, we pass a global progress bar
         if pbar is None:
@@ -1452,6 +1476,7 @@ def send_request(i):
         hash_schema = pa.schema([(self.id_field, pa.string()), ("_blob_hash", pa.string())])
         returned_ids = []
         returned_hashes = []
+        failed_ids = []
         succeeded = 0
 
         with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -1461,6 +1486,10 @@ def send_request(i):
                 response = future.result()
                 # add hash to data as _blob_hash
                 for uuid, blob_hash in response.items():
+                    if blob_hash is None:
+                        failed_ids.append(uuid)
+                        continue
+
                     returned_ids.append(uuid)
                     returned_hashes.append(blob_hash)
 
@@ -1468,6 +1497,13 @@ def send_request(i):
                 succeeded += len(response)
                 pbar.update(len(response))
 
+        # remove all rows that failed to upload
+        if len(failed_ids) > 0:
+            failed_ids_array = pa.array(failed_ids, type=pa.string())
+            logger.info(f"Failed to upload {len(failed_ids)} blobs.")
+            logger.info(f"Filtering out {failed_ids} from the dataset.")
+            data = pc.filter(data, pc.invert(pc.is_in(data[self.id_field], failed_ids_array)))  # type: ignore
+
         hash_tb = pa.Table.from_pydict({self.id_field: returned_ids, "_blob_hash": returned_hashes}, schema=hash_schema)
         merged_data = data.join(right_table=hash_tb, keys=self.id_field)  # type: ignore
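For reference, here is a minimal usage sketch of the URL-blob path this diff enables. The dataset name, record fields, and image URLs below are hypothetical placeholders; only `AtlasDataset`, `add_data(data=..., blobs=...)`, and `create_index(embedding_model=..., topic_model=...)` come from the code above. Note that `_add_blobs` now rejects a batch that mixes local files with URLs, so each kind goes in its own `add_data` call.

```python
from nomic import AtlasDataset

# Hypothetical dataset name; unique_id_field must match the "id" key in the records.
dataset = AtlasDataset("my-org/url-blob-demo", unique_id_field="id")

# Remote blobs: http(s) URLs are forwarded to the blob upload endpoint as "urls".
url_records = [{"id": "0", "caption": "a heron"}, {"id": "1", "caption": "an oak leaf"}]
image_urls = ["https://example.com/heron.jpg", "https://example.com/oak_leaf.jpg"]
dataset.add_data(data=url_records, blobs=image_urls)

# Local blobs (paths, bytes, or PIL Images) go in a separate call: mixing them with
# URLs in one batch raises ValueError("Cannot mix local and remote blobs in the same batch.").
local_records = [{"id": "2", "caption": "a fern"}]
dataset.add_data(data=local_records, blobs=["./images/fern.jpg"])

dataset.create_index(embedding_model="nomic-embed-vision-v1.5", topic_model=False)
```

Per the `failed_ids` handling above, any URL the server fails to fetch comes back with a null hash, and its row is filtered out of the data before the hash table join.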