From 28f1f82d7c5e2247bb99322daf58f38b99023486 Mon Sep 17 00:00:00 2001 From: NanoNabla <43477372+NanoNabla@users.noreply.github.com> Date: Wed, 5 May 2021 15:24:27 +0200 Subject: [PATCH 1/2] multi machine parallelization using horovod --- setup.py | 9 + training/coqui_stt_training/train.py | 195 +++++++++++++------- training/coqui_stt_training/util/config.py | 35 +++- training/coqui_stt_training/util/feeding.py | 29 ++- training/coqui_stt_training/util/flags.py | 2 + 5 files changed, 186 insertions(+), 84 deletions(-) diff --git a/setup.py b/setup.py index 565a657fa..76e2c9615 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,9 @@ def main(): 'tensorflow == 1.15.4' ] + horovod_pypi_dep = [ + 'horovod[tensorflow] == 0.21.3' + ] if os.environ.get('DS_NODECODER', ''): install_requires = install_requires_base else: @@ -49,6 +52,12 @@ def main(): else: install_requires = install_requires + tensorflow_pypi_dep + + if os.environ.get('DS_WITH_HOROVOD', ''): + install_requires = install_requires + horovod_pypi_dep + else: + install_requires = install_requires + setup( name='coqui_stt_training', version=version, diff --git a/training/coqui_stt_training/train.py b/training/coqui_stt_training/train.py index 77d324871..48d6519ac 100644 --- a/training/coqui_stt_training/train.py +++ b/training/coqui_stt_training/train.py @@ -414,6 +414,13 @@ def log_grads_and_vars(grads_and_vars): def train(): exception_box = ExceptionBox() + + if FLAGS.horovod: + import horovod.tensorflow as hvd + + # Create training and validation datasets + split_dataset = FLAGS.horovod + # Create training and validation datasets train_set = create_dataset(FLAGS.train_files.split(','), batch_size=FLAGS.train_batch_size, @@ -422,10 +429,11 @@ def train(): cache_path=FLAGS.feature_cache, train_phase=True, exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, + process_ahead=Config.num_devices * FLAGS.train_batch_size * 2, reverse=FLAGS.reverse_train, limit=FLAGS.limit_train, - buffering=FLAGS.read_buffer) + buffering=FLAGS.read_buffer, + split_dataset=split_dataset) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), tfv1.data.get_output_shapes(train_set), @@ -441,10 +449,11 @@ def train(): train_phase=False, augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2, reverse=FLAGS.reverse_dev, limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer) for source in dev_sources] + buffering=FLAGS.read_buffer, + split_dataset=split_dataset) for source in dev_sources] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] if FLAGS.metrics_files: @@ -454,10 +463,11 @@ def train(): train_phase=False, augmentations=[NormalizeSampleRate(FLAGS.audio_sample_rate)], exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, + process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2, reverse=FLAGS.reverse_dev, limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer) for source in metrics_sources] + buffering=FLAGS.read_buffer, + split_dataset=split_dataset) for source in metrics_sources] metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets] # Dropout @@ -479,20 +489,39 @@ def train(): reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction)) optimizer = 
create_optimizer(learning_rate_var) + + if FLAGS.horovod: + # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size. + optimizer = create_optimizer(learning_rate_var * hvd.size()) + optimizer = hvd.DistributedOptimizer(optimizer) + else: + optimizer = create_optimizer(learning_rate_var) + # Enable mixed precision training if FLAGS.automatic_mixed_precision: log_info('Enabling automatic mixed precision training.') optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) - gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates) + if FLAGS.horovod: + loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False) + gradients = optimizer.compute_gradients(loss) - # Average tower gradients across GPUs - avg_tower_gradients = average_gradients(gradients) - log_grads_and_vars(avg_tower_gradients) + tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries']) + log_grads_and_vars(gradients) + + # global_step is automagically incremented by the optimizer + global_step = tfv1.train.get_or_create_global_step() + apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step) + else: + gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates) - # global_step is automagically incremented by the optimizer - global_step = tfv1.train.get_or_create_global_step() - apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step) + # Average tower gradients across GPUs + avg_tower_gradients = average_gradients(gradients) + log_grads_and_vars(avg_tower_gradients) + + # global_step is automagically incremented by the optimizer + global_step = tfv1.train.get_or_create_global_step() + apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step) # Summaries step_summaries_op = tfv1.summary.merge_all('step_summaries') @@ -509,18 +538,22 @@ def train(): } # Checkpointing - checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep) - checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train') + if Config.is_master_process: + checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep) + checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train') + + best_dev_saver = tfv1.train.Saver(max_to_keep=1) + best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev') - best_dev_saver = tfv1.train.Saver(max_to_keep=1) - best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev') + # Save flags next to checkpoints + if not is_remote_path(FLAGS.save_checkpoint_dir): + os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) + flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') + with open_remote(flags_file, 'w') as fout: + fout.write(FLAGS.flags_into_string()) - # Save flags next to checkpoints - if not is_remote_path(FLAGS.save_checkpoint_dir): - os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) - flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') - with open_remote(flags_file, 'w') as fout: - fout.write(FLAGS.flags_into_string()) + if FLAGS.horovod: + bcast = hvd.broadcast_global_variables(0) with tfv1.Session(config=Config.session_config) as session: log_debug('Session opened.') @@ -531,6 +564,9 @@ def train(): # Load checkpoint or initialize variables load_or_init_graph_for_training(session) + if FLAGS.horovod: + 
bcast.run() + def run_set(set_name, epoch, init_op, dataset=None): is_train = set_name == 'train' train_op = apply_gradient_op if is_train else [] @@ -557,12 +593,13 @@ def __call__(self, progress, data, **kwargs): data['mean_loss'] = total_loss / step_count if step_count else 0.0 return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs) - prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name]) - widgets = [' | ', progressbar.widgets.Timer(), - ' | Steps: ', progressbar.widgets.Counter(), - ' | ', LossWidget()] - suffix = ' | Dataset: {}'.format(dataset) if dataset else None - pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start() + if Config.is_master_process: + prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name]) + widgets = [' | ', progressbar.widgets.Timer(), + ' | Steps: ', progressbar.widgets.Counter(), + ' | ', LossWidget()] + suffix = ' | Dataset: {}'.format(dataset) if dataset else None + pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start() # Initialize iterator to the appropriate dataset session.run(init_op) @@ -586,19 +623,22 @@ def __call__(self, progress, data, **kwargs): total_loss += batch_loss step_count += 1 - pbar.update(step_count) + if Config.is_master_process: + pbar.update(step_count) - step_summary_writer.add_summary(step_summary, current_step) + step_summary_writer.add_summary(step_summary, current_step) - if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs: - checkpoint_saver.save(session, checkpoint_path, global_step=current_step) - checkpoint_time = time.time() + if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs: + checkpoint_saver.save(session, checkpoint_path, global_step=current_step) + checkpoint_time = time.time() - pbar.finish() + if Config.is_master_process: + pbar.finish() mean_loss = total_loss / step_count if step_count > 0 else 0.0 return mean_loss, step_count - log_info('STARTING Optimization') + if Config.is_master_process: + log_info('STARTING Optimization') train_start_time = datetime.utcnow() best_dev_loss = float('inf') dev_losses = [] @@ -606,21 +646,25 @@ def __call__(self, progress, data, **kwargs): try: for epoch in range(FLAGS.epochs): # Training - log_progress('Training epoch %d...' % epoch) + if Config.is_master_process: + log_progress('Training epoch %d...' % epoch) train_loss, _ = run_set('train', epoch, train_init_op) - log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss)) - checkpoint_saver.save(session, checkpoint_path, global_step=global_step) + if Config.is_master_process: + log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss)) + checkpoint_saver.save(session, checkpoint_path, global_step=global_step) if FLAGS.dev_files: # Validation dev_loss = 0.0 total_steps = 0 for source, init_op in zip(dev_sources, dev_init_ops): - log_progress('Validating epoch %d on %s...' % (epoch, source)) + if Config.is_master_process: + log_progress('Validating epoch %d on %s...' 
% (epoch, source)) set_loss, steps = run_set('dev', epoch, init_op, dataset=source) dev_loss += set_loss * steps total_steps += steps - log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss)) + if Config.is_master_process: + log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss)) dev_loss = dev_loss / total_steps dev_losses.append(dev_loss) @@ -633,15 +677,17 @@ def __call__(self, progress, data, **kwargs): epochs_without_improvement = 0 # Save new best model - if dev_loss < best_dev_loss: - best_dev_loss = dev_loss - save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') - log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) + if Config.is_master_process: + if dev_loss < best_dev_loss: + best_dev_loss = dev_loss + save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') + log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) # Early stopping if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs: - log_info('Early stop triggered as the loss did not improve the last {} epochs'.format( - epochs_without_improvement)) + if Config.is_master_process: + log_info('Early stop triggered as the loss did not improve the last {} epochs'.format( + epochs_without_improvement)) break # Reduce learning rate on plateau @@ -658,26 +704,32 @@ def __call__(self, progress, data, **kwargs): # Reduce learning rate session.run(reduce_learning_rate_op) current_learning_rate = learning_rate_var.eval() - log_info('Encountered a plateau, reducing learning rate to {}'.format( - current_learning_rate)) + if Config.is_master_process: + log_info('Encountered a plateau, reducing learning rate to {}'.format( + current_learning_rate)) # Overwrite best checkpoint with new learning rate value save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') - log_info("Saved best validating model with reduced learning rate to: %s" % (save_path)) + if Config.is_master_process: + log_info("Saved best validating model with reduced learning rate to: %s" % (save_path)) if FLAGS.metrics_files: # Read only metrics, not affecting best validation loss tracking for source, init_op in zip(metrics_sources, metrics_init_ops): - log_progress('Metrics for epoch %d on %s...' % (epoch, source)) + if Config.is_master_process: + log_progress('Metrics for epoch %d on %s...' 
% (epoch, source))
                         set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
-                        log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
+                        if Config.is_master_process:
+                            log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
 
-                print('-' * 80)
+                if Config.is_master_process:
+                    print('-' * 80)
 
         except KeyboardInterrupt:
             pass
-        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
+        if Config.is_master_process:
+            log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
     log_debug('Session closed.')
@@ -960,28 +1012,29 @@ def main(_):
         tfv1.set_random_seed(FLAGS.random_seed)
         train()
 
-    if FLAGS.test_files:
-        tfv1.reset_default_graph()
-        test()
+    if Config.is_master_process:
+        if FLAGS.test_files:
+            tfv1.reset_default_graph()
+            test()
 
-    if FLAGS.export_dir and not FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        export()
+        if FLAGS.export_dir and not FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            export()
 
-    if FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        FLAGS.export_tflite = True
+        if FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            FLAGS.export_tflite = True
 
-        if listdir_remote(FLAGS.export_dir):
-            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
-            sys.exit(1)
+            if listdir_remote(FLAGS.export_dir):
+                log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
+                sys.exit(1)
 
-        export()
-        package_zip()
+            export()
+            package_zip()
 
-    if FLAGS.one_shot_infer:
-        tfv1.reset_default_graph()
-        do_single_file_inference(FLAGS.one_shot_infer)
+        if FLAGS.one_shot_infer:
+            tfv1.reset_default_graph()
+            do_single_file_inference(FLAGS.one_shot_infer)
 
 
 def run_script():
diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py
index f3b25bb30..3fa53db92 100755
--- a/training/coqui_stt_training/util/config.py
+++ b/training/coqui_stt_training/util/config.py
@@ -79,12 +79,37 @@ def initialize_globals():
     # CPU device
     c.cpu_device = '/cpu:0'
 
-    # Available GPU devices
-    c.available_devices = get_available_gpus(c.session_config)
+    if FLAGS.horovod:
+        try:
+            import horovod.tensorflow as hvd
+        except ImportError as e:
+            print(
+                "Error importing Horovod. Did you install DeepSpeech with 'DS_WITH_HOROVOD=y'? 
" + "If you do not want to use horovod, do not start with '--horovod=True'") + raise e + + hvd.init() + + # Pin GPU to be used to process local rank (one GPU per process) + c.session_config.gpu_options.visible_device_list = str(hvd.local_rank()) + c.num_devices = hvd.size() + c.is_master_process = True if hvd.rank() == 0 else False + else: + # Available GPU devices + c.available_devices = get_available_gpus(c.session_config) + + # If there is no GPU available, we fall back to CPU based operation + if not c.available_devices: + c.available_devices = [c.cpu_device] + + c.num_devices = len(c.available_devices) + + # If there are no horovod processes the only one should handled like horovod master + c.is_master_process = True - # If there is no GPU available, we fall back to CPU based operation - if not c.available_devices: - c.available_devices = [c.cpu_device] + # If there is no GPU available, we fall back to CPU based operation + if not c.available_devices: + c.available_devices = [c.cpu_device] if FLAGS.bytes_output_mode: c.alphabet = UTF8Alphabet() diff --git a/training/coqui_stt_training/util/feeding.py b/training/coqui_stt_training/util/feeding.py index 30a2b2f47..dc2724694 100644 --- a/training/coqui_stt_training/util/feeding.py +++ b/training/coqui_stt_training/util/feeding.py @@ -94,7 +94,8 @@ def create_dataset(sources, limit=0, exception_box=None, process_ahead=None, - buffering=1 * MEGABYTE): + buffering=1 * MEGABYTE, + split_dataset = False): epoch_counter = Counter() # survives restarts of the dataset and its generator def generate_values(): @@ -135,14 +136,26 @@ def batch_fn(sample_ids, features, features_len, transcripts): process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations) - dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), - output_types=(tf.string, tf.float32, tf.int32, - (tf.int64, tf.int32, tf.int64), tf.float64)) - .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)) + + dataset = tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), + output_types=(tf.string, tf.float32, tf.int32, + (tf.int64, tf.int32, tf.int64), tf.float64)) + if split_dataset: + # Using horovod Iterator.get_next() is not aware of different devices. + # A.shard(n, i) will contain all elements of A whose index mod n = i. 
+        import horovod.tensorflow as hvd
+        dataset = dataset.shard(hvd.size(), hvd.rank())
+    dataset = dataset.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
     if cache_path:
         dataset = dataset.cache(cache_path)
-    dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)
-               .prefetch(len(Config.available_devices)))
+
+    dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn))
+    if split_dataset:
+        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+    else:
+        dataset = dataset.prefetch(Config.num_devices)
+
     return dataset
 
 
@@ -178,5 +191,5 @@ def create_batch_set(bs, criteria):
         ods = create_batch_set(outlier_batch_size,
                                lambda start, end, f, fl: end - start > int(outlier_duration_ms))
         dataset = nds.concatenate(ods)
-        dataset = dataset.prefetch(len(Config.available_devices))
+        dataset = dataset.prefetch(Config.num_devices)
     return dataset
diff --git a/training/coqui_stt_training/util/flags.py b/training/coqui_stt_training/util/flags.py
index 179a987de..939c39468 100644
--- a/training/coqui_stt_training/util/flags.py
+++ b/training/coqui_stt_training/util/flags.py
@@ -69,6 +69,8 @@ def create_flags():
     f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
     f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
 
+    f.DEFINE_boolean('horovod', False, 'use Horovod for training on multiple GPUs')
+
     # Sample limits
     f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit')
 
From 5f23b13121ea358fbc0795333d2fbb3b7c5d8fd0 Mon Sep 17 00:00:00 2001
From: NanoNabla <43477372+NanoNabla@users.noreply.github.com>
Date: Wed, 5 May 2021 15:43:07 +0200
Subject: [PATCH 2/2] horovod documentation

---
 doc/TRAINING_ADVANCED.rst |  2 ++
 doc/TRAINING_HOROVOD.rst  | 22 ++++++++++++++++++++++
 2 files changed, 24 insertions(+)
 create mode 100644 doc/TRAINING_HOROVOD.rst

diff --git a/doc/TRAINING_ADVANCED.rst b/doc/TRAINING_ADVANCED.rst
index d49ce0baa..bf6e4a2bf 100644
--- a/doc/TRAINING_ADVANCED.rst
+++ b/doc/TRAINING_ADVANCED.rst
@@ -17,3 +17,5 @@ This document contains more advanced topics with regard to training models with
 9. :ref:`parallel-training-optimization`
 10. :ref:`data-importers`
 11. :ref:`byte-output-mode`
+12. :ref:`horovod-parallel-training`
+
diff --git a/doc/TRAINING_HOROVOD.rst b/doc/TRAINING_HOROVOD.rst
new file mode 100644
index 000000000..969042d70
--- /dev/null
+++ b/doc/TRAINING_HOROVOD.rst
@@ -0,0 +1,22 @@
+.. _horovod-parallel-training:
+
+Distributed training using Horovod
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you have a capable compute architecture, it is possible to distribute training using `Horovod <https://github.com/horovod/horovod>`_. A fast network is recommended.
+Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication.
+It also offers `Gloo <https://github.com/facebookincubator/gloo>`_ as an easy-to-setup communication backend.
+
+For more information about setting up or tuning Horovod, please visit `Horovod's documentation <https://horovod.readthedocs.io/>`_.
+
+Horovod is expected to run on heterogeneous systems (e.g. a different number and model of GPUs per machine).
+However, this can cause unpredictable problems and would require manual intervention in the training code.
+Therefore, we only support homogeneous systems, which means the same hardware and the same software configuration (OS, drivers, MPI, NCCL, TensorFlow, ...) on every machine.
+The only exception is a different number of GPUs per machine, since this can be controlled via ``horovodrun -H``.
+
+Detailed documentation on how to run Horovod is also available in `Horovod's documentation <https://horovod.readthedocs.io/>`_.
+The short command to train on 4 machines using 4 GPUs each:
+
+.. code-block:: bash
+
+    horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod
\ No newline at end of file
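
The train.py and config.py changes above follow Horovod's standard data-parallel pattern for the TF1 graph API: ``hvd.init()``, pin one GPU per process via ``visible_device_list``, scale the learning rate by ``hvd.size()``, wrap the optimizer in ``hvd.DistributedOptimizer``, broadcast the initial variables from rank 0, and let only rank 0 write checkpoints. Below is a minimal, self-contained sketch of that pattern on a toy model; it is not Coqui's actual training loop, and the model, shapes and hyperparameters are placeholders.

.. code-block:: python

    import numpy as np
    import tensorflow.compat.v1 as tfv1
    import horovod.tensorflow as hvd

    tfv1.disable_eager_execution()

    hvd.init()  # one process per GPU, launched e.g. via horovodrun

    # Pin each process to a single local GPU (mirrors the config.py change above).
    session_config = tfv1.ConfigProto()
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Toy stand-in model: a single dense layer doing linear regression.
    x = tfv1.placeholder(tfv1.float32, [None, 1])
    y = tfv1.placeholder(tfv1.float32, [None, 1])
    loss = tfv1.losses.mean_squared_error(y, tfv1.layers.dense(x, 1))

    # Scale the learning rate by the worker count, then let Horovod average
    # gradients across workers (mirrors the optimizer block in train.py above).
    optimizer = tfv1.train.AdamOptimizer(0.001 * hvd.size())
    optimizer = hvd.DistributedOptimizer(optimizer)
    global_step = tfv1.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)

    # Rank 0 owns the initial variable values; every other rank receives them here.
    bcast = hvd.broadcast_global_variables(0)

    with tfv1.Session(config=session_config) as session:
        session.run(tfv1.global_variables_initializer())
        bcast.run()
        for _ in range(100):
            xs = np.random.rand(16, 1).astype(np.float32)
            session.run(train_op, feed_dict={x: xs, y: 2 * xs})
        if hvd.rank() == 0:  # only the master process writes checkpoints
            tfv1.train.Saver().save(session, './toy_checkpoint')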
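
The feeding.py change relies on ``tf.data``'s ``shard(num_shards, index)`` transformation rather than on splitting the source CSVs: every worker builds the identical pipeline and then keeps only the elements whose index modulo ``hvd.size()`` equals its own rank, so workers read disjoint slices of the same dataset. Sharding before the heavier ``map``/``cache`` steps, as the patch does, also means each worker only featurizes its own slice. A small sketch of the semantics, with the rank and size hard-coded instead of coming from Horovod:

.. code-block:: python

    import tensorflow.compat.v1 as tfv1

    tfv1.disable_eager_execution()

    num_workers = 4   # would be hvd.size() in training
    worker_rank = 1   # would be hvd.rank() in training

    dataset = tfv1.data.Dataset.range(10).shard(num_workers, worker_rank)
    next_element = tfv1.data.make_one_shot_iterator(dataset).get_next()

    with tfv1.Session() as session:
        try:
            while True:
                print(session.run(next_element))  # prints 1, 5, 9
        except tfv1.errors.OutOfRangeError:
            pass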
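
On the packaging side, the setup.py hunk makes Horovod a strictly opt-in dependency keyed off the ``DS_WITH_HOROVOD`` environment variable, so a default install never pulls in ``horovod[tensorflow]``. A condensed sketch of that pattern follows; the base requirement list here is a placeholder, and the exact install command depends on how you normally install the training package (e.g. something like ``DS_WITH_HOROVOD=y pip install -e .``).

.. code-block:: python

    import os

    install_requires = ['numpy', 'progressbar2']  # placeholder for the real base requirements
    horovod_pypi_dep = ['horovod[tensorflow] == 0.21.3']

    # Horovod is only added when DS_WITH_HOROVOD is set in the environment at install time.
    if os.environ.get('DS_WITH_HOROVOD', ''):
        install_requires = install_requires + horovod_pypi_dep

    print(install_requires)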