diff --git a/objathor/dataset/generate_holodeck_features.py b/objathor/dataset/generate_holodeck_features.py index 815f767..6d2f16c 100644 --- a/objathor/dataset/generate_holodeck_features.py +++ b/objathor/dataset/generate_holodeck_features.py @@ -111,6 +111,9 @@ def generate_features( batch_size: int, num_workers: int, ): + base_dir = os.path.join(base_dir, "features") + os.makedirs(base_dir, exist_ok=True) + # CLIP device = torch.device(device) clip_model_name = "ViT-L-14" @@ -165,9 +168,11 @@ def generate_features( pbar.update(len(batch["uid"])) - clip_img_features = torch.cat(clip_img_features, dim=0).numpy() - clip_text_features = torch.cat(clip_text_features, dim=0).numpy() - sbert_text_features = torch.cat(sbert_text_features, dim=0).numpy() + clip_img_features = torch.cat(clip_img_features, dim=0).numpy().astype("float16") + clip_text_features = torch.cat(clip_text_features, dim=0).numpy().astype("float16") + sbert_text_features = ( + torch.cat(sbert_text_features, dim=0).numpy().astype("float16") + ) compress_pickle.dump( { diff --git a/objathor/dataset/postprocess_assets.py b/objathor/dataset/postprocess_assets.py index 85e4776..69d5593 100644 --- a/objathor/dataset/postprocess_assets.py +++ b/objathor/dataset/postprocess_assets.py @@ -33,8 +33,11 @@ def filter_func(tarinfo): return tarinfo -def create_tar_of_assets(assets_dir: str, save_dir: str): - save_path = os.path.abspath(os.path.join(save_dir, "assets.tar")) +def create_tar_of_directory_with_exclusions(dir_to_tar: str, save_dir: str): + dir_to_tar = os.path.abspath(dir_to_tar) + save_path = os.path.abspath( + os.path.join(save_dir, f"{os.path.basename(dir_to_tar)}.tar") + ) if os.path.exists(save_path): print(f"{save_path} already exists. Skipping...") @@ -42,7 +45,7 @@ def create_tar_of_assets(assets_dir: str, save_dir: str): cur_dir = os.getcwd() try: - os.chdir(assets_dir) # Change to the directory where the assets are located + os.chdir(dir_to_tar) # Change to the directory where the assets are located with tarfile.open(save_path, "w") as tar: for root, dirs, files in tqdm.tqdm(os.walk(".")): for file in files: @@ -65,7 +68,7 @@ def postprocess_assets(dataset_dir: str, batch_size: int, num_workers: int): # Create assets.tar print("Creating assets.tar...") - create_tar_of_assets(assets_dir=assets_dir, save_dir=dataset_dir) + create_tar_of_directory_with_exclusions(dir_to_tar=assets_dir, save_dir=dataset_dir) # Generating holodeck features print("Generating holodeck features...") @@ -77,6 +80,9 @@ def postprocess_assets(dataset_dir: str, batch_size: int, num_workers: int): batch_size=batch_size if torch.cuda.is_available() else 8, num_workers=num_workers, ) + create_tar_of_directory_with_exclusions( + dir_to_tar=os.path.join(dataset_dir, "features"), save_dir=dataset_dir + ) if __name__ == "__main__": diff --git a/objathor/dataset/upload_dataset.py b/objathor/dataset/upload_dataset.py index 7e3e8e9..6cfe8f0 100644 --- a/objathor/dataset/upload_dataset.py +++ b/objathor/dataset/upload_dataset.py @@ -70,8 +70,7 @@ def upload_dataset(dataset_dir: str, bucket_path: str): for name in [ "annotations.json.gz", "assets.tar", - "clip_features.pkl", - "sbert_features.pkl", + "features.tar", ] ]