-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathslabpredict.py
38 lines (30 loc) · 1 KB
/
slabpredict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
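"""
slabpredict.py
--------------
Loads the flux-mapping network from its SavedModel export and runs
inference on a single 1024x128x1024 slab of univ_000_real.hdf5: the 'DM'
field from index 896 onward along axis 1. The prediction is saved to
generated_Lya_redshift.npy, and the mean absolute error against the
matching 'FT' slab is printed.
"""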
from __future__ import absolute_import, division, print_function, unicode_literals

import os

import h5py
import numpy as np
import tensorflow as tf
# Load the exported SavedModel and take its default serving signature.
loaded = tf.saved_model.load("./trained_models/fluxmapping_SavedModel")
print('Loaded model')
print(list(loaded.signatures.keys()))
infer = loaded.signatures["serving_default"]
# Print the output spec so the output key used below ('lambda') can be checked.
print(infer.structured_outputs)

# Slab geometry. The slice starts at NGRID - SIZE (== 896, the offset that
# used to be hard-coded) and runs to the end of axis 1. It therefore yields
# exactly SIZE cells only if axis 1 has NGRID cells.
# NOTE(review): 1024 is taken from the module docstring; confirm against the data.
NGRID = 1024
SIZE = 128
SLAB_START = NGRID - SIZE
Ddir = './data/'

with h5py.File(os.path.join(Ddir, 'univ_000_real.hdf5'), 'r') as hf:
    sampleDM = hf['DM'][:, SLAB_START:, :].astype(np.float32)
    sampleFT = hf['FT'][:, SLAB_START:, :].astype(np.float32)
print('Loaded samples')

# Add a leading batch axis and a trailing channel axis for the network input.
dat = np.expand_dims(sampleDM, axis=0)
dat = np.expand_dims(dat, axis=-1)
print(dat.shape)

pred = infer(tf.constant(dat))
print('DONE predicting')
pred = pred['lambda'].numpy()
print(pred.shape)
np.save(os.path.join(Ddir, 'generated_Lya_redshift.npy'), pred)

# Strip the batch/channel axes and make sure the prediction lines up with the
# target. Without this check, a mismatched but broadcastable shape would still
# produce a (meaningless) error value below.
pred_field = pred[0, :, :, :, 0]
if pred_field.shape != sampleFT.shape:
    raise ValueError(
        "prediction shape {} does not match target shape {}".format(
            pred_field.shape, sampleFT.shape))
# Mean absolute error between the true FT slab and the prediction.
print(np.mean(np.abs(sampleFT - pred_field)))