datadownloader.py
"""
Script to download the data used for the SBUnfold paper. Same data as Omnifold paper.
Downloads from here https://zenodo.org/records/3548091
The first part of the script downloads all parts of the data in a lot of *.npz files.
The second part merges them into one file per dataset.
Need to set the save_dir variable at the top. The rest should run then.
The "while" and "try" loop looks a bit scuffed. The download sometimes fails, this restarts it at the last array.
"""
import os

import energyflow
import numpy as np

# Base directory for the downloaded data; adjust to your local setup.
save_dir = '/remote/gpu07/huetsch/data'
datasets = ['Herwig', 'Pythia21', 'Pythia25', 'Pythia26']
# Part 1: download all parts of every dataset, retrying on failure.
for dataset in datasets:
    finished = False
    while not finished:
        try:
            energyflow.zjets_delphes.load(dataset, num_data=-1, pad=False, cache_dir=save_dir,
                                          source='zenodo', which='all',
                                          include_keys=None, exclude_keys=None)
            finished = True
        except Exception as error:
            print("Failed", dataset, error)

# Part 2: merge the per-part files into one *_full.npz file per dataset.
data_dir = os.path.join(save_dir, 'datasets', 'ZjetsDelphes')
all_files = os.listdir(data_dir)
all_datasets = ['Herwig', 'Pythia21', 'Pythia25', 'Pythia26']
for dataset in all_datasets:
    out_dict = {}
    outfile = os.path.join(data_dir, dataset + "_full.npz")
    dataset_files = [os.path.join(data_dir, file) for file in all_files if dataset in file]
    assert len(dataset_files) == 17, dataset_files
    data = [np.load(file) for file in dataset_files]
    assert len(data) == 17
    keys = data[0].files
    for key in keys:
        # Skip the per-particle arrays; they are not merged here.
        if "particles" in key:
            continue
        out_dict[key] = np.concatenate([part[key] for part in data], axis=0)
    # Save each key as its own named array so np.load(outfile)[key] works directly.
    with open(outfile, "wb") as f:
        np.savez(f, **out_dict)
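
# Optional verification (a sketch; assumes the merge above completed): reload each
# merged file and print the stored keys and array shapes, so a truncated or
# partial merge is easy to spot.
for dataset in all_datasets:
    merged_path = os.path.join(data_dir, dataset + "_full.npz")
    with np.load(merged_path) as merged:
        shapes = {key: merged[key].shape for key in merged.files}
        print(dataset, shapes)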