main.py
"""Entrypoint for training DETR in JAX."""
import os
import flax
import jax
import tensorflow as tf
import trainer
from absl import app, flags, logging
from clu import metric_writers
from ml_collections import config_flags
from train_lib import train_utils
logging.set_verbosity('info')
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
config_flags.DEFINE_config_file(
'config', None, 'Training configuration.', lock_config=True)
flags.DEFINE_string('workdir', None, 'Path to store checkpoints and logs.')
flax.config.update('flax_use_orbax_checkpointing', False)
FLAGS = flags.FLAGS
def main(unused_argv):
  cfg = FLAGS.config
  workdir = FLAGS.workdir

  # Initialize multi-process JAX when launched under Open MPI (mpirun sets
  # OMPI_COMM_WORLD_SIZE in each process's environment).
  if 'OMPI_COMM_WORLD_SIZE' in os.environ:
    jax.distributed.initialize()

  # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory and make
  # it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  rng = jax.random.PRNGKey(cfg.rng_seed)
  logging.info('RNG: %s', rng)

  # Derive an independent key for the input pipeline so dataset shuffling and
  # model initialization do not share randomness.
  data_rng, rng = jax.random.split(rng)

  # Write metrics asynchronously so logging does not block the training loop.
  writer = metric_writers.AsyncWriter(
      metric_writers.SummaryWriter(logdir=workdir))

  dataset = train_utils.get_dataset(cfg, rng=data_rng)

  trainer.train_and_evaluate(
      rng=rng, dataset=dataset, config=cfg, workdir=workdir, writer=writer)


if __name__ == '__main__':
  flags.mark_flags_as_required(['config', 'workdir'])
  app.run(main)