Skip to content

Commit

Permalink
Update autocaption main function to accept either a dataset or directory
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonrising committed Jun 4, 2024
1 parent 0b44077 commit 0c502eb
Showing 1 changed file with 14 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ def process_images(images: list[Image.Image], prompt: str, moondream, tokenizer)
return answers


def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_path: str):
def main(
prompt: str,
use_cpu: bool,
batch_size: int,
output_path: str,
image_dir: str = None,
dataset: torch.utils.data.Dataset = None,
):
device, dtype = select_device_and_dtype(use_cpu)
print(f"Using device: {device}")
print(f"Using dtype: {dtype}")
Expand All @@ -53,8 +60,11 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
moondream_model.eval()

# Prepare the dataloader.
dataset = ImageDirDataset(image_dir)
print(f"Found {len(dataset)} images in '{image_dir}'.")
if image_dir is not None:
dataset = ImageDirDataset(image_dir)
print(f"Found {len(dataset)} images in '{image_dir}'.")
if not dataset:
raise ValueError("Either 'image_dir' or 'dataset' must be provided to this function.")
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False
)
Expand Down Expand Up @@ -107,4 +117,4 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
)
args = parser.parse_args()

main(args.dir, args.prompt, args.cpu, args.batch_size, args.output)
main(args.prompt, args.cpu, args.batch_size, args.output, image_dir=args.dir)

0 comments on commit 0c502eb

Please sign in to comment.