diff --git a/tools/ImgDataset2WebDatasetMS.py b/tools/ImgDataset2WebDatasetMS.py index bbe6672..03fbfbd 100644 --- a/tools/ImgDataset2WebDatasetMS.py +++ b/tools/ImgDataset2WebDatasetMS.py @@ -1,13 +1,14 @@ -# -*- coding: utf-8 -*- # @Author: Pevernow (wzy3450354617@gmail.com) # @Date: 2025/1/5 # @License: (Follow the main project) -from PIL import PngImagePlugin -PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024 # Increase maximum size for text chunks -import os import json +import os import tarfile -from PIL import Image + +from PIL import Image, PngImagePlugin + +PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024 # Increase maximum size for text chunks + def process_data(input_dir, output_tar_name="output.tar"): """ @@ -37,17 +38,12 @@ def process_data(input_dir, output_tar_name="output.tar"): with Image.open(png_filepath) as img: width, height = img.size - with open(txt_filename, 'r', encoding='utf-8') as f: + with open(txt_filename, encoding="utf-8") as f: caption_content = f.read().strip() - data = { - "file_name": filename, - "prompt": caption_content, - "width": width, - "height": height - } + data = {"file_name": filename, "prompt": caption_content, "width": width, "height": height} - with open(json_filepath, 'w', encoding='utf-8') as outfile: + with open(json_filepath, "w", encoding="utf-8") as outfile: json.dump(data, outfile, indent=4, ensure_ascii=False) print(f"Generated: {json_filename}") @@ -59,7 +55,7 @@ def process_data(input_dir, output_tar_name="output.tar"): print(f"Warning: No corresponding TXT file found for {filename}.") # Create a TAR file and include all files - with tarfile.open(output_tar_name, 'w') as tar: + with tarfile.open(output_tar_name, "w") as tar: for item in os.listdir(input_dir): item_path = os.path.join(input_dir, item) tar.add(item_path, arcname=item) # arcname maintains the relative path of the file in the tar @@ -67,7 +63,10 @@ def process_data(input_dir, output_tar_name="output.tar"): print(f"\nAll files have been packaged into: {output_tar_name}") print(f"Number of PNG images processed: {png_count}") + if __name__ == "__main__": input_directory = input("Please enter the directory path containing PNG and TXT files: ") - output_tar_filename = input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar" + output_tar_filename = ( + input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar" + ) process_data(input_directory, output_tar_filename)