Skip to content

Commit

Permalink
Changing where holodeck features are saved.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucaweihs committed Mar 7, 2024
1 parent 682f4d3 commit 3194010
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
11 changes: 8 additions & 3 deletions objathor/dataset/generate_holodeck_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def generate_features(
batch_size: int,
num_workers: int,
):
base_dir = os.path.join(base_dir, "features")
os.makedirs(base_dir, exist_ok=True)

# CLIP
device = torch.device(device)
clip_model_name = "ViT-L-14"
Expand Down Expand Up @@ -165,9 +168,11 @@ def generate_features(

pbar.update(len(batch["uid"]))

clip_img_features = torch.cat(clip_img_features, dim=0).numpy()
clip_text_features = torch.cat(clip_text_features, dim=0).numpy()
sbert_text_features = torch.cat(sbert_text_features, dim=0).numpy()
clip_img_features = torch.cat(clip_img_features, dim=0).numpy().astype("float16")
clip_text_features = torch.cat(clip_text_features, dim=0).numpy().astype("float16")
sbert_text_features = (
torch.cat(sbert_text_features, dim=0).numpy().astype("float16")
)

compress_pickle.dump(
{
Expand Down
14 changes: 10 additions & 4 deletions objathor/dataset/postprocess_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,19 @@ def filter_func(tarinfo):
return tarinfo


def create_tar_of_assets(assets_dir: str, save_dir: str):
save_path = os.path.abspath(os.path.join(save_dir, "assets.tar"))
def create_tar_of_directory_with_exclusions(dir_to_tar: str, save_dir: str):
dir_to_tar = os.path.abspath(dir_to_tar)
save_path = os.path.abspath(
os.path.join(save_dir, f"{os.path.basename(dir_to_tar)}.tar")
)

if os.path.exists(save_path):
print(f"{save_path} already exists. Skipping...")
return

cur_dir = os.getcwd()
try:
os.chdir(assets_dir) # Change to the directory where the assets are located
os.chdir(dir_to_tar) # Change to the directory where the assets are located
with tarfile.open(save_path, "w") as tar:
for root, dirs, files in tqdm.tqdm(os.walk(".")):
for file in files:
Expand All @@ -65,7 +68,7 @@ def postprocess_assets(dataset_dir: str, batch_size: int, num_workers: int):

# Create assets.tar
print("Creating assets.tar...")
create_tar_of_assets(assets_dir=assets_dir, save_dir=dataset_dir)
create_tar_of_directory_with_exclusions(dir_to_tar=assets_dir, save_dir=dataset_dir)

# Generating holodeck features
print("Generating holodeck features...")
Expand All @@ -77,6 +80,9 @@ def postprocess_assets(dataset_dir: str, batch_size: int, num_workers: int):
batch_size=batch_size if torch.cuda.is_available() else 8,
num_workers=num_workers,
)
create_tar_of_directory_with_exclusions(
dir_to_tar=os.path.join(dataset_dir, "features"), save_dir=dataset_dir
)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions objathor/dataset/upload_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def upload_dataset(dataset_dir: str, bucket_path: str):
for name in [
"annotations.json.gz",
"assets.tar",
"clip_features.pkl",
"sbert_features.pkl",
"features.tar",
]
]

Expand Down

0 comments on commit 3194010

Please sign in to comment.