-
Notifications
You must be signed in to change notification settings - Fork 6
/
simple_example.py
28 lines (24 loc) · 1018 Bytes
/
simple_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from xca.examples.arxiv200800283.example_synthesis import pattern_simulation
from xca.ml.tf.data_proc import xr_dir_to_record, _int64_feature
from xca.ml.tf.cnn import build_CNN_model, build_fusion_ensemble_model, training
from xca.ml.tf.utils import load_hyperparameters
from pathlib import Path
def main(
    *,
    systems=("BaTiO", "ADTA", "NiCoAl"),
    n_patterns=100,
    work_dir="tmp",
):
    """Simulate data, build TFRecords, and train one CNN ensemble per system.

    For each system, in order, this:

    1. calls ``pattern_simulation(n_patterns, system)`` and keeps the returned
       ``mapping``. The record transform looks values up in it (presumably
       each pattern's ``input_cif`` attribute) and encodes the result with
       ``_int64_feature``;
    2. converts ``<work_dir>/<system>`` into ``<work_dir>/<system>.tfrecords``
       with ``xr_dir_to_record``;
    3. loads hyperparameters from ``<system>_training.json`` (a relative path,
       so it is resolved against the current working directory), builds a
       fusion ensemble of ``build_CNN_model`` networks, trains it, prints the
       training result, and saves the model.

    Args:
        systems: Names of the systems to process. Defaults to the three
            systems used by the original example.
        n_patterns: First positional argument forwarded to
            ``pattern_simulation`` (presumably the number of patterns to
            simulate -- confirm against ``example_synthesis``). Defaults to 100.
        work_dir: Directory holding the simulated data, the TFRecord files and
            the saved models. Defaults to ``"tmp"``.
    """
    work_dir = Path(work_dir)
    for system in systems:
        mapping = pattern_simulation(n_patterns, system)
        # Bind ``mapping`` as a default argument so the transform keeps this
        # iteration's mapping even if it were called after the loop advances
        # (avoids the late-binding-closure pitfall, ruff B023).
        xr_dir_to_record(
            work_dir / f"{system}",
            work_dir / f"{system}.tfrecords",
            attrs_key="input_cif",
            transform=lambda x, mapping=mapping: _int64_feature(mapping[x]),
        )
        params = load_hyperparameters(params_file=f"{system}_training.json")
        # Pop ``ensemble_size`` first so it is not part of the ``**params``
        # forwarded below; doing it as its own statement makes the ordering
        # explicit instead of relying on argument evaluation order.
        ensemble_size = params.pop("ensemble_size", 1)
        model = build_fusion_ensemble_model(
            ensemble_size, build_CNN_model, **params
        )
        res = training(model, **params)
        print(f"Results for {system}")
        print(res)
        # NOTE(review): this is the same path that was read above as the
        # simulated-data directory. Kept for backward compatibility; consider a
        # dedicated output location so data and model artifacts don't mix.
        model.save(str(work_dir / f"{system}"))
if __name__ == "__main__":
    # Only run the example pipeline when this file is executed as a script,
    # never as a side effect of importing it.
    main()