Use h5py for output data writing and consolidation to reduce memory footprint #10

Closed
wants to merge 2 commits into from
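For context, here is a toy sketch (not part of the PR) of the batched output layout that `consolidate_data` in the diff below appears to expect: one HDF5 group per batch, each holding a `protein_length` vector plus a padded, batch-major array for every other output key. The group name `'0'`, the `loss` key, and the array shapes are illustrative assumptions; only `protein_length` and `encoder_output` are named in the diff.

```python
import h5py
import numpy as np

# Toy batched output file in the layout consolidate_data() appears to read:
# one group per batch, 'protein_length' giving per-sequence lengths, and the
# other datasets padded along the sequence axis. Key names other than
# 'protein_length' / 'encoder_output' are illustrative.
with h5py.File('outputs.h5', 'w') as f:
    batch = f.create_group('0')  # one group per batch
    batch.create_dataset('protein_length', data=np.array([3, 5]))
    # per-protein scalar output (illustrative key)
    batch.create_dataset('loss', data=np.array([0.71, 0.42], dtype=np.float32))
    # padded to the longest sequence in the batch (5 residues here)
    batch.create_dataset('encoder_output',
                         data=np.zeros((2, 5, 128), dtype=np.float32))
```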
setup.py (1 addition, 1 deletion)
@@ -20,7 +20,7 @@
    install_requires=[
        'tensorflow-gpu<1.14.0', # https://github.com/IDSIA/sacred/issues/493
        'numpy',
-        'rinokeras==1.1.1',
+        'rinokeras==1.1.2',
        'biopython',
        'sacred',
        'table_logger',
tape/__main__.py (34 additions, 26 deletions)
@@ -5,7 +5,9 @@
import os
import shutil
import pickle as pkl
+import uuid

+import h5py
import tensorflow as tf

from tape.tasks import TaskBuilder, Task, AbstractLanguageModelingTask
@@ -207,31 +209,37 @@ def cleanup_folders(outdir: str, model, tasks, debug):


def consolidate_data(outfile, include_hidden: bool = False):
"""
Turn batched h5 output file into flat h5 file
"""

with open(outfile, 'rb') as f:
outputs = pkl.load(f)

data = defaultdict(list) # type: ignore

for output in outputs:
output = output[0]
length = output['protein_length']
for key, protein_batch in output.items():
for protein_length, protein_data in zip(length, protein_batch):
if np.isscalar(protein_data):
data[key].append(protein_data)
elif protein_data.ndim == 1 and protein_data.dtype in [np.float32, np.float64]:
data[key].append(protein_data)
else:
data[key].append(protein_data[:protein_length])

data = dict(data)

if not include_hidden:
del data['encoder_output']

with open(outfile, 'wb') as f:
pkl.dump(data, f)
tmp_id = uuid.uuid1().hex # just in case there's some weirdness
tmp_filename = 'outputs_tmp_{}.h5'.format(tmp_id)
i = 0
with h5py.File(outfile, 'r') as f, h5py.File(tmp_filename, 'w') as f_out:
for key in f.keys(): # iterate over all batches
output = f[key]
length = output['protein_length'][()]
n_seqs = len(length)
for key, protein_batch in output.items():
protein_batch = protein_batch[()]
# iterate over all proteins in the batch
for index, protein_length, protein_data in zip(range(i, i+n_seqs), length, protein_batch):
try:
grp = f_out[str(index)]
except KeyError:
grp = f_out.create_group(str(index))
if np.isscalar(protein_data):
grp.create_dataset(key, data=protein_data)
elif protein_data.ndim == 1 and protein_data.dtype in [np.float32, np.float64]:
grp.create_dataset(key, data=protein_data)
else:
# truncate by length of the sequence to remove padding
grp.create_dataset(key, data=protein_data[:protein_length])
i += n_seqs

# be careful, this could take up many GB of disk space! (especially for the LSTM)
os.replace(tmp_filename, outfile)


@proteins.command
Expand Down Expand Up @@ -270,9 +278,9 @@ def eval(_run, _config, tasks: Union[str, List[str]], model: str):
experiment.distribution_strategy, task_model, _config['load_task_from'])

task_dir = os.path.dirname(_config['load_task_from'])
outfile = os.path.join(task_dir, 'outputs.pkl')
outfile = os.path.join(task_dir, 'outputs.h5')
print('Saving outputs to {}'.format(outfile))
test_metrics = test_graph.run_epoch(save_outputs=outfile)
test_metrics = test_graph.run_epoch(save_outputs=outfile, save_format='h5')
print(test_metrics.get_average())
consolidate_data(outfile, include_hidden=True)

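For reference, a minimal sketch (not part of the PR) of reading the consolidated `outputs.h5` back with standard h5py calls, assuming the per-protein group layout that `consolidate_data` above creates: one group per protein, keyed by its running index as a string, with one dataset per output key and padding already stripped. Whether `encoder_output` is present depends on which outputs the task saves; `hidden_size` in the comment is illustrative.

```python
import h5py

# Iterate over the per-protein groups written by consolidate_data().
with h5py.File('outputs.h5', 'r') as f:
    for index in sorted(f.keys(), key=int):        # group names are str(index)
        protein = f[index]
        length = protein['protein_length'][()]     # stored as a scalar
        embedding = protein['encoder_output'][()]  # e.g. shape (length, hidden_size)
        print(index, length, embedding.shape)
```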