diff --git a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
index 301323ff..a3e5c99c 100644
--- a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
+++ b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
@@ -32,7 +32,13 @@ def process_images(images: list[Image.Image], prompt: str, moondream, tokenizer)
     return answers
 
 
-def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_path: str):
+def main(
+    prompt: str,
+    use_cpu: bool,
+    batch_size: int,
+    output_path: str,
+    dataset: torch.utils.data.Dataset,
+):
     device, dtype = select_device_and_dtype(use_cpu)
     print(f"Using device: {device}")
     print(f"Using dtype: {dtype}")
@@ -52,9 +58,6 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
     ).to(device=device, dtype=dtype)
     moondream_model.eval()
 
-    # Prepare the dataloader.
-    dataset = ImageDirDataset(image_dir)
-    print(f"Found {len(dataset)} images in '{image_dir}'.")
     data_loader = torch.utils.data.DataLoader(
         dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False
     )
@@ -107,4 +110,8 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
     )
     args = parser.parse_args()
 
-    main(args.dir, args.prompt, args.cpu, args.batch_size, args.output)
+    # Prepare the dataset.
+    dataset = ImageDirDataset(args.dir)
+    print(f"Found {len(dataset)} images in '{args.dir}'.")
+
+    main(args.prompt, args.cpu, args.batch_size, args.output, dataset)
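
With this change, main() no longer constructs its own ImageDirDataset from a directory path; the caller builds the dataset and injects it, so any torch.utils.data.Dataset that yields items in the same format can be used. A minimal usage sketch follows (not part of the patch): the import path for ImageDirDataset, the prompt, the batch size, and the file paths are all illustrative assumptions, not taken from this diff.

    # Usage sketch only. The ImageDirDataset import path is assumed;
    # adjust it to wherever the class actually lives in the repo.
    from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset  # path assumed
    from invoke_training.scripts._experimental.auto_caption.auto_caption_images import main

    # Build the dataset outside main(), mirroring the new argparse code path.
    dataset = ImageDirDataset("data/my_images")  # hypothetical image directory

    main(
        prompt="Describe this image.",  # illustrative prompt
        use_cpu=False,
        batch_size=4,                   # illustrative batch size
        output_path="captions.jsonl",   # illustrative output file
        dataset=dataset,
    )

Passing the dataset in keeps model setup and the captioning loop independent of where the images come from, which makes main() reusable from other scripts and easier to exercise with a small stub dataset in tests.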