-
Notifications
You must be signed in to change notification settings - Fork 0
/
pro_imgs.py
54 lines (40 loc) · 1.9 KB
/
pro_imgs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import base64
import os
import multiprocessing
from PIL import Image
from io import BytesIO
from tqdm import tqdm
import json
def save_image_from_base64(base64_data, fp, output_folder):
    """Decode a base64-encoded image and write it to disk as a PNG.

    Failures are reported on stdout and signalled through the return value
    instead of being raised, so one bad record does not abort a whole batch.

    Args:
        base64_data: Base64 text (or bytes) of the encoded image.
        fp: File name of the PNG to create inside ``output_folder``.
        output_folder: Directory the PNG is written to.

    Returns:
        True if the image was written, False if decoding or saving failed.
    """
    try:
        image_data = base64.b64decode(base64_data)
        # The context manager releases the image's resources as soon as the
        # PNG has been written instead of waiting for garbage collection.
        with Image.open(BytesIO(image_data)) as image:
            image.save(os.path.join(output_folder, fp), "PNG")
    except Exception as e:
        # Deliberately broad: this is the per-item boundary of a best-effort
        # batch job.  The file name is included so the failing record can be
        # identified when many workers print concurrently.
        print(f"Error saving image {fp}: {e}")
        return False
    print(f"Saved: {fp}")
    return True
def get_total(base64_image_list, fps, output_folder="/home/data/aigc/images",
              processes=24):
    """Save a batch of base64-encoded images as PNG files in parallel.

    Args:
        base64_image_list: Iterable of base64-encoded images.
        fps: Iterable of output file names, paired positionally with
            ``base64_image_list`` (extra items on either side are ignored,
            as with ``zip``).
        output_folder: Directory the PNG files are written to; created,
            including missing parents, when it does not exist.
        processes: Upper bound (positive int) on the number of worker
            processes.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_folder, exist_ok=True)

    jobs = [(data, fp, output_folder) for data, fp in zip(base64_image_list, fps)]
    if not jobs:
        # Nothing to save: do not pay for spawning a worker pool.
        return

    # Save the base64 images as PNG files using multiple processes; never
    # start more workers than there are images to write.
    with multiprocessing.Pool(processes=min(processes, len(jobs))) as pool:
        pool.starmap(save_image_from_base64, jobs)
def pro_total(dataset_path='/home/data/aigc/data/part-00000'):
    """Extract the images and captions of one tab-separated dataset shard.

    Each line is split on tabs; field 3 (0-based) is taken as the
    base64-encoded image and field 4 as its caption.  The images are handed
    to ``get_total`` to be written as ``<shard>_<index>.png`` and the
    file-name -> caption mapping is dumped to
    ``/home/data/aigc/<shard>.json``.

    Lines with fewer than five fields (e.g. a blank trailing line) are
    skipped and counted instead of aborting the whole shard.

    Args:
        dataset_path: Path of the shard file to process.
    """
    dataset_name = os.path.basename(dataset_path)
    base64_image_list = []
    captions = []
    skipped = 0

    # NOTE(review): assumes the shard is UTF-8 text -- confirm against the
    # data producer.  An explicit encoding keeps the result independent of
    # the machine's locale.
    with open(dataset_path, 'r', encoding='utf-8') as file:
        for line in tqdm(file):
            fields = line.split('\t')
            if len(fields) < 5:
                skipped += 1
                continue
            base64_image_list.append(fields[3])
            captions.append(fields[4].strip('\n'))

    if skipped:
        print(f"Skipped {skipped} malformed line(s) in {dataset_path}")

    # Indices count accepted records, so names and captions stay aligned.
    fps = [dataset_name + '_' + str(i) + '.png' for i in range(len(captions))]
    image2caption = dict(zip(fps, captions))

    get_total(base64_image_list, fps)

    # ensure_ascii=False writes non-ASCII captions verbatim, so the output
    # file must be opened with an encoding that can represent them.
    with open('/home/data/aigc/{}.json'.format(dataset_name), 'w', encoding='utf-8') as f:
        json.dump(image2caption, f, indent='\t', ensure_ascii=False)
if __name__ == "__main__":
    # Only the first shard (part-00000) is processed at the moment.  A full
    # run would cover the 20 shards part-00000 .. part-00019.
    shard_template = '/home/data/aigc/data/part-0000{}'
    for shard_index in range(1):
        pro_total(shard_template.format(shard_index))