From 9f489bf0bea014bc0345c6ff0d82269489c184be Mon Sep 17 00:00:00 2001 From: Zeya Wang Date: Thu, 16 Jul 2020 00:30:41 -0400 Subject: [PATCH] Update readme and example script (#19) --- .gitignore | 1 - README.md | 9 +- docs/usage/tutorials/getting-started.md | 2 +- examples/benchmark/README.md | 8 +- examples/benchmark/bert.py | 17 +- examples/benchmark/utils/logs/cloud_lib.py | 34 ++ examples/benchmark/utils/logs/hooks.py | 130 ++++++ examples/benchmark/utils/logs/hooks_helper.py | 172 +++++++ examples/benchmark/utils/logs/logger.py | 423 ++++++++++++++++++ examples/benchmark/utils/logs/metric_hook.py | 97 ++++ .../benchmark/utils/logs/mlperf_helper.py | 192 ++++++++ examples/benchmark/utils/optimization.py | 195 -------- 12 files changed, 1061 insertions(+), 219 deletions(-) create mode 100644 examples/benchmark/utils/logs/cloud_lib.py create mode 100644 examples/benchmark/utils/logs/hooks.py create mode 100644 examples/benchmark/utils/logs/hooks_helper.py create mode 100644 examples/benchmark/utils/logs/logger.py create mode 100644 examples/benchmark/utils/logs/metric_hook.py create mode 100644 examples/benchmark/utils/logs/mlperf_helper.py delete mode 100644 examples/benchmark/utils/optimization.py diff --git a/.gitignore b/.gitignore index 81bbca4..316008c 100644 --- a/.gitignore +++ b/.gitignore @@ -87,7 +87,6 @@ tags # project-specifc **/data -**/logs* docs/_build/ docs/api/ docs/symlink* diff --git a/README.md b/README.md index 8db106d..7a38d78 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@

-[![pipeline status](https://img.shields.io/badge/dynamic/json?url=https://jenkins.petuum.io/job/AutoDist/job/master/lastCompletedBuild/api/json&label=build&query=$.result&color=informational)](https://github.com/petuum/autodist/commits/master) -[![coverage report](https://img.shields.io/badge/dynamic/json?url=https://jenkins.petuum.io/job/AutoDist/job/master/lastSuccessfulBuild/artifact/coverage-report/jenkinscovdata.json&label=coverage&query=$.total_coverage_pct&color=9cf)](https://github.com/petuum/autodist/commits/master) +[![pipeline status](https://img.shields.io/badge/dynamic/json?url=https://jenkins.petuum.io/job/AutoDist/job/master/lastCompletedBuild/api/json&label=build&query=$.result&color=informational)](https://jenkins.petuum.io/job/AutoDist/job/master/) +[![coverage report](https://img.shields.io/badge/dynamic/json?url=https://jenkins.petuum.io/job/AutoDist/job/master/lastSuccessfulBuild/artifact/coverage-report/jenkinscovdata.json&label=coverage&query=$.total_coverage_pct&color=green)](https://jenkins.petuum.io/job/AutoDist/job/master/lastSuccessfulBuild/artifact/) +[![pypi version](https://img.shields.io/pypi/v/autodist?color=9cf)](https://pypi.org/project/autodist/) [Documentation](https://petuum.github.io/autodist) | [Examples](https://github.com/petuum/autodist/tree/master/examples/benchmark) @@ -11,8 +12,6 @@ AutoDist provides a user-friendly interface to distribute the training of a wide variety of deep learning models across many GPUs with scalability and minimal code change. -AutoDist has been tested with TensorFlow versions 1.15 through 2.1. - ## Introduction Different from specialized distributed ML systems, AutoDist is created to speed up a broad range of DL models with excellent all-around performance. AutoDist achieves this goal by: @@ -29,6 +28,8 @@ for all-level users.

+For a closer look at the performance, please refer to our [doc](https://petuum.github.io/autodist/usage/performance.html). + ## Using AutoDist Installation: diff --git a/docs/usage/tutorials/getting-started.md b/docs/usage/tutorials/getting-started.md index 1ea0928..614329e 100644 --- a/docs/usage/tutorials/getting-started.md +++ b/docs/usage/tutorials/getting-started.md @@ -7,7 +7,7 @@ We recommended reviewing the [TensorFlow Quickstart Guide](https://www.tensorflo particularly the difference between eager and graph mode. If you can run the [Quickstart](https://www.tensorflow.org/tutorials/quickstart/advanced) properly, you can use the same environment to follow this tutorial. -AutoDist currently supports `Python>=3.6` with `tensorflow>=1.15, <=2.1`. Install the downloaded wheel file by +AutoDist currently supports `Python>=3.6` with `tensorflow>=1.15, <=2.2`. Install the downloaded wheel file by ```bash pip install autodist diff --git a/examples/benchmark/README.md b/examples/benchmark/README.md index 54dbbd1..bd21337 100644 --- a/examples/benchmark/README.md +++ b/examples/benchmark/README.md @@ -9,19 +9,19 @@ The instruction for generating the tfrecord data for ImageNet can be found follo ``` # You can set cnn models from vgg16, resnet101, densenet121, inceptionv3 export CNN_MODEL=resnet101 -python ${REAL_PATH}/imagenet.py --data_dir=${REAL_PATH}/train --train_epochs=10 --cnn_model=$CNN_MODEL --autodist_strategy=$AUTODIST_STRATEGY -# ${REAL_PATH} is the real path you place the code and dataset +python ${REAL_SCRIPT_PATH}/imagenet.py --data_dir=${REAL_DATA_PATH}/train --train_epochs=10 --cnn_model=$CNN_MODEL --autodist_strategy=$AUTODIST_STRATEGY +# ${REAL_SCRIPT_PATH} and ${REAL_DATA_PATH} are the real paths you place the code and dataset ``` #### Bidirectional Encoder Representations from Transformers (BERT) The instruction for generating the training data and setting up the pre-trained model with the config file can be found following [this 
link](https://github.com/tensorflow/models/tree/master/official/nlp/bert). ``` -python ${REAL_PATH}/bert.py -input_files=${REAL_PATH}/sample_data_tfrecord/*.tfrecord --bert_config_file=${REAL_PATH}/uncased_L-24_H-1024_A-16/bert_config --num_train_epochs=1 --learning_rate=5e-5 --steps_per_loop=20 --autodist_strategy=$AUTODIST_STRATEGY +python ${REAL_SCRIPT_PATH}/bert.py -input_files=${REAL_DATA_PATH}/sample_data_tfrecord/*.tfrecord --bert_config_file=${REAL_DATA_PATH}/uncased_L-24_H-1024_A-16/bert_config --num_train_epochs=1 --learning_rate=5e-5 --steps_per_loop=20 --autodist_strategy=$AUTODIST_STRATEGY ``` #### Neural Collaborative Filtering (NCF) The instruction for generating the training data can be found following [this link](https://github.com/tensorflow/models/tree/master/official/recommendation). ``` -python ${REAL_PATH}/ncf.py --default_data_dir=${REAL_PATH}/movielens --autodist_strategy=$AUTODIST_STRATEGY +python ${REAL_SCRIPT_PATH}/ncf.py --default_data_dir=${REAL_DATA_PATH}/movielens --autodist_strategy=$AUTODIST_STRATEGY ``` diff --git a/examples/benchmark/bert.py b/examples/benchmark/bert.py index b25b18e..9f992b1 100644 --- a/examples/benchmark/bert.py +++ b/examples/benchmark/bert.py @@ -31,7 +31,6 @@ from utils.logs import logger from utils.misc import keras_utils -from utils import optimization from utils import bert_modeling as modeling from utils import bert_models from utils import common_flags @@ -63,8 +62,6 @@ flags.DEFINE_integer('chunk_size', 256, 'The chunk size for training.') flags.DEFINE_integer('num_steps_per_epoch', 1000, 'Total number of training steps to run per epoch.') -flags.DEFINE_float('warmup_steps', 10, - 'Warmup steps for Adam weight decay optimizer.') flags.DEFINE_string( name='autodist_strategy', default='PS', @@ -73,10 +70,7 @@ name='autodist_patch_tf', default=True, help='AUTODIST_PATCH_TF') -flags.DEFINE_string( - name='optimizer', - default='AdamDecay', - help='the optimizer to be chosen') + 
flags.DEFINE_boolean(name='proxy', default=True, help='turn on off the proxy') @@ -122,7 +116,6 @@ def run_customized_training(strategy, steps_per_loop, epochs, initial_lr, - warmup_steps, input_files, train_batch_size): """Run BERT pretrain model training using low-level API.""" @@ -140,11 +133,8 @@ def _get_pretrain_model(): """Gets a pretraining model.""" pretrain_model, core_model = bert_models.pretrain_model( bert_config, max_seq_length, max_predictions_per_seq) - if FLAGS.optimizer == 'AdamDecay': - pretrain_model.optimizer = optimization.create_optimizer( - initial_lr, steps_per_epoch * epochs, warmup_steps) - else: - pretrain_model.optimizer = tf.optimizers.Adam(lr=initial_lr) + + pretrain_model.optimizer = tf.optimizers.Adam(lr=initial_lr) if FLAGS.fp16_implementation == 'graph_rewrite': pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite( pretrain_model.optimizer) @@ -190,7 +180,6 @@ def run_bert_pretrain(strategy, gpu_num=1, node_num=1): FLAGS.steps_per_loop, FLAGS.num_train_epochs, FLAGS.learning_rate, - FLAGS.warmup_steps, FLAGS.input_files, FLAGS.train_batch_size * gpu_num * node_num) diff --git a/examples/benchmark/utils/logs/cloud_lib.py b/examples/benchmark/utils/logs/cloud_lib.py new file mode 100644 index 0000000..a2d9bd3 --- /dev/null +++ b/examples/benchmark/utils/logs/cloud_lib.py @@ -0,0 +1,34 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Utilities that interact with cloud service. +""" + +import requests + +GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname" +GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"} + + +def on_gcp(): + """Detect whether the current running environment is on GCP.""" + try: + # Timeout in 5 seconds, in case the test environment has connectivity issues. + # There is no default timeout, which means it might block forever. + response = requests.get( + GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=5) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False diff --git a/examples/benchmark/utils/logs/hooks.py b/examples/benchmark/utils/logs/hooks.py new file mode 100644 index 0000000..50ace2d --- /dev/null +++ b/examples/benchmark/utils/logs/hooks.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== + +"""Hook that counts examples per second every N steps or seconds.""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf # pylint: disable=g-bad-import-order + +from utils.logs import logger + + +class ExamplesPerSecondHook(tf.estimator.SessionRunHook): + """Hook to print out examples per second. + + Total time is tracked and then divided by the total number of steps + to get the average step time and then batch_size is used to determine + the running average of examples per second. The examples per second for the + most recent interval is also logged. + """ + + def __init__(self, + batch_size, + every_n_steps=None, + every_n_secs=None, + warm_steps=0, + metric_logger=None): + """Initializer for ExamplesPerSecondHook. + + Args: + batch_size: Total batch size across all workers used to calculate + examples/second from global time. + every_n_steps: Log stats every n steps. + every_n_secs: Log stats every n seconds. Exactly one of the + `every_n_steps` or `every_n_secs` should be set. + warm_steps: The number of steps to be skipped before logging and running + average calculation. warm_steps refers to global steps across all + workers, not on each worker. + metric_logger: instance of `BenchmarkLogger`, the benchmark logger that + hook should use to write the log. If None, BaseBenchmarkLogger will + be used. + + Raises: + ValueError: if neither `every_n_steps` nor `every_n_secs` is set, or + both are set.
+ """ + + if (every_n_steps is None) == (every_n_secs is None): + raise ValueError("exactly one of every_n_steps" + " and every_n_secs should be provided.") + + self._logger = metric_logger or logger.BaseBenchmarkLogger() + + self._timer = tf.estimator.SecondOrStepTimer( + every_steps=every_n_steps, every_secs=every_n_secs) + + self._step_train_time = 0 + self._total_steps = 0 + self._batch_size = batch_size + self._warm_steps = warm_steps + # List of examples per second logged every_n_steps. + self.current_examples_per_sec_list = [] + + def begin(self): + """Called once before using the session to check global step.""" + self._global_step_tensor = tf.compat.v1.train.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError( + "Global step should be created to use StepCounterHook.") + + def before_run(self, run_context): # pylint: disable=unused-argument + """Called before each call to run(). + + Args: + run_context: A SessionRunContext object. + + Returns: + A SessionRunArgs object or None if never triggered. + """ + return tf.estimator.SessionRunArgs(self._global_step_tensor) + + def after_run(self, run_context, run_values): # pylint: disable=unused-argument + """Called after each call to run(). + + Args: + run_context: A SessionRunContext object. + run_values: A SessionRunValues object. 
+ """ + global_step = run_values.results + + if self._timer.should_trigger_for_step( + global_step) and global_step > self._warm_steps: + elapsed_time, elapsed_steps = self._timer.update_last_triggered_step( + global_step) + if elapsed_time is not None: + self._step_train_time += elapsed_time + self._total_steps += elapsed_steps + + # average examples per second is based on the total (accumulative) + # training steps and training time so far + average_examples_per_sec = self._batch_size * ( + self._total_steps / self._step_train_time) + # current examples per second is based on the elapsed training steps + # and training time per batch + current_examples_per_sec = self._batch_size * ( + elapsed_steps / elapsed_time) + # Logs entries to be read from hook during or after run. + self.current_examples_per_sec_list.append(current_examples_per_sec) + self._logger.log_metric( + "average_examples_per_sec", average_examples_per_sec, + global_step=global_step) + + self._logger.log_metric( + "current_examples_per_sec", current_examples_per_sec, + global_step=global_step) diff --git a/examples/benchmark/utils/logs/hooks_helper.py b/examples/benchmark/utils/logs/hooks_helper.py new file mode 100644 index 0000000..bb333dc --- /dev/null +++ b/examples/benchmark/utils/logs/hooks_helper.py @@ -0,0 +1,172 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Hooks helper to return a list of TensorFlow hooks for training by name. + +More hooks can be added to this set. To add a new hook, 1) add the new hook to +the registry in HOOKS, 2) add a corresponding function that parses out necessary +parameters. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf # pylint: disable=g-bad-import-order + +from utils.logs import hooks +from utils.logs import logger +from utils.logs import metric_hook + +_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate', + 'cross_entropy', + 'train_accuracy']) + + +def get_train_hooks(name_list, use_tpu=False, **kwargs): + """Factory for getting a list of TensorFlow hooks for training by name. + + Args: + name_list: a list of strings to name desired hook classes. Allowed: + LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined + as keys in HOOKS + use_tpu: Boolean of whether computation occurs on a TPU. This will disable + hooks altogether. + **kwargs: a dictionary of arguments to the hooks. + + Returns: + list of instantiated hooks, ready to be used in a classifier.train call. + + Raises: + ValueError: if an unrecognized name is passed. + """ + + if not name_list: + return [] + + if use_tpu: + tf.compat.v1.logging.warning('hooks_helper received name_list `{}`, but a ' + 'TPU is specified. No hooks will be used.' + .format(name_list)) + return [] + + train_hooks = [] + for name in name_list: + hook_name = HOOKS.get(name.strip().lower()) + if hook_name is None: + raise ValueError('Unrecognized training hook requested: {}'.format(name)) + else: + train_hooks.append(hook_name(**kwargs)) + + return train_hooks + + +def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs): # pylint: disable=unused-argument + """Function to get LoggingTensorHook. 
+ + Args: + every_n_iter: `int`, print the values of `tensors` once every N local + steps taken on the current worker. + tensors_to_log: List of tensor names or dictionary mapping labels to tensor + names. If not set, log _TENSORS_TO_LOG by default. + **kwargs: a dictionary of arguments to LoggingTensorHook. + + Returns: + Returns a LoggingTensorHook with a standard set of tensors that will be + printed to stdout. + """ + if tensors_to_log is None: + tensors_to_log = _TENSORS_TO_LOG + + return tf.estimator.LoggingTensorHook( + tensors=tensors_to_log, + every_n_iter=every_n_iter) + + +def get_profiler_hook(model_dir, save_steps=1000, **kwargs): # pylint: disable=unused-argument + """Function to get ProfilerHook. + + Args: + model_dir: The directory to save the profile traces to. + save_steps: `int`, print profile traces every N steps. + **kwargs: a dictionary of arguments to ProfilerHook. + + Returns: + Returns a ProfilerHook that writes out timelines that can be loaded into + profiling tools like chrome://tracing. + """ + return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir) + + +def get_examples_per_second_hook(every_n_steps=100, + batch_size=128, + warm_steps=5, + **kwargs): # pylint: disable=unused-argument + """Function to get ExamplesPerSecondHook. + + Args: + every_n_steps: `int`, print current and average examples per second every + N steps. + batch_size: `int`, total batch size used to calculate examples/second from + global time. + warm_steps: skip this number of steps before logging and running average. + **kwargs: a dictionary of arguments to ExamplesPerSecondHook. + + Returns: + Returns an ExamplesPerSecondHook that logs the current and average + examples per second to the benchmark logger.
+ """ + return hooks.ExamplesPerSecondHook( + batch_size=batch_size, every_n_steps=every_n_steps, + warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger()) + + +def get_logging_metric_hook(tensors_to_log=None, + every_n_secs=600, + **kwargs): # pylint: disable=unused-argument + """Function to get LoggingMetricHook. + + Args: + tensors_to_log: List of tensor names or dictionary mapping labels to tensor + names. If not set, log _TENSORS_TO_LOG by default. + every_n_secs: `int`, the frequency for logging the metric. Default to every + 10 mins. + **kwargs: a dictionary of arguments. + + Returns: + Returns a LoggingMetricHook that saves tensor values in a JSON format. + """ + if tensors_to_log is None: + tensors_to_log = _TENSORS_TO_LOG + return metric_hook.LoggingMetricHook( + tensors=tensors_to_log, + metric_logger=logger.get_benchmark_logger(), + every_n_secs=every_n_secs) + + +def get_step_counter_hook(**kwargs): + """Function to get StepCounterHook.""" + del kwargs + return tf.estimator.StepCounterHook() + + +# A dictionary to map one hook name and its corresponding function +HOOKS = { + 'loggingtensorhook': get_logging_tensor_hook, + 'profilerhook': get_profiler_hook, + 'examplespersecondhook': get_examples_per_second_hook, + 'loggingmetrichook': get_logging_metric_hook, + 'stepcounterhook': get_step_counter_hook +} diff --git a/examples/benchmark/utils/logs/logger.py b/examples/benchmark/utils/logs/logger.py new file mode 100644 index 0000000..db2f04e --- /dev/null +++ b/examples/benchmark/utils/logs/logger.py @@ -0,0 +1,423 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Logging utilities for benchmark. + +For collecting local environment metrics like CPU and memory, certain python +packages need be installed. See README for details. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import datetime +import json +import multiprocessing +import numbers +import os +import threading +import uuid + +from six.moves import _thread as thread +from absl import flags +import tensorflow as tf +from tensorflow.python.client import device_lib + +from utils.logs import cloud_lib + +METRIC_LOG_FILE_NAME = "metric.log" +BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log" +_DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ" +GCP_TEST_ENV = "GCP" +RUN_STATUS_SUCCESS = "success" +RUN_STATUS_FAILURE = "failure" +RUN_STATUS_RUNNING = "running" + + +FLAGS = flags.FLAGS + +# Don't use it directly. Use get_benchmark_logger to access a logger. 
+_benchmark_logger = None +_logger_lock = threading.Lock() + + +def config_benchmark_logger(flag_obj=None): + """Config the global benchmark logger.""" + _logger_lock.acquire() + try: + global _benchmark_logger + if not flag_obj: + flag_obj = FLAGS + + if (not hasattr(flag_obj, "benchmark_logger_type") or + flag_obj.benchmark_logger_type == "BaseBenchmarkLogger"): + _benchmark_logger = BaseBenchmarkLogger() + elif flag_obj.benchmark_logger_type == "BenchmarkFileLogger": + _benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir) + elif flag_obj.benchmark_logger_type == "BenchmarkBigQueryLogger": + from benchmark import benchmark_uploader as bu # pylint: disable=g-import-not-at-top + bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project) + _benchmark_logger = BenchmarkBigQueryLogger( + bigquery_uploader=bq_uploader, + bigquery_data_set=flag_obj.bigquery_data_set, + bigquery_run_table=flag_obj.bigquery_run_table, + bigquery_run_status_table=flag_obj.bigquery_run_status_table, + bigquery_metric_table=flag_obj.bigquery_metric_table, + run_id=str(uuid.uuid4())) + else: + raise ValueError("Unrecognized benchmark_logger_type: %s" + % flag_obj.benchmark_logger_type) + + finally: + _logger_lock.release() + return _benchmark_logger + + +def get_benchmark_logger(): + if not _benchmark_logger: + config_benchmark_logger() + return _benchmark_logger + + +@contextlib.contextmanager +def benchmark_context(flag_obj): + """Context of benchmark, which will update status of the run accordingly.""" + benchmark_logger = config_benchmark_logger(flag_obj) + try: + yield + benchmark_logger.on_finish(RUN_STATUS_SUCCESS) + except Exception: # pylint: disable=broad-except + # Catch all the exception, update the run status to be failure, and re-raise + benchmark_logger.on_finish(RUN_STATUS_FAILURE) + raise + + +class BaseBenchmarkLogger(object): + """Class to log the benchmark information to STDOUT.""" + + def log_evaluation_result(self, eval_results): + """Log the 
evaluation result. + + The evaluation result is a dictionary that contains metrics defined in + model_fn. It also contains an entry for global_step which contains the value + of the global step when evaluation was performed. + + Args: + eval_results: dict, the result of evaluate. + """ + if not isinstance(eval_results, dict): + tf.compat.v1.logging.warning( + "eval_results should be dictionary for logging. Got %s", + type(eval_results)) + return + global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP] + for key in sorted(eval_results): + if key != tf.compat.v1.GraphKeys.GLOBAL_STEP: + self.log_metric(key, eval_results[key], global_step=global_step) + + def log_metric(self, name, value, unit=None, global_step=None, extras=None): + """Log the benchmark metric information through TensorFlow logging. + + Currently the logging is done in a synchronized way. This should be updated + to log asynchronously. + + Args: + name: string, the name of the metric to log. + value: number, the value of the metric. The value will not be logged if it + is not a number type. + unit: string, the unit of the metric, E.g "image per second". + global_step: int, the global_step when the metric is logged. + extras: map of string:string, the extra information about the metric.
+ """ + metric = _process_metric_to_json(name, value, unit, global_step, extras) + if metric: + tf.compat.v1.logging.info("Benchmark metric: %s", metric) + + def log_run_info(self, model_name, dataset_name, run_params, test_id=None): + tf.compat.v1.logging.info( + "Benchmark run: %s", _gather_run_info(model_name, dataset_name, + run_params, test_id)) + + def on_finish(self, status): + pass + + +class BenchmarkFileLogger(BaseBenchmarkLogger): + """Class to log the benchmark information to local disk.""" + + def __init__(self, logging_dir): + super(BenchmarkFileLogger, self).__init__() + self._logging_dir = logging_dir + if not tf.io.gfile.isdir(self._logging_dir): + tf.io.gfile.makedirs(self._logging_dir) + self._metric_file_handler = tf.io.gfile.GFile( + os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a") + + def log_metric(self, name, value, unit=None, global_step=None, extras=None): + """Log the benchmark metric information to local file. + + Currently the logging is done in a synchronized way. This should be updated + to log asynchronously. + + Args: + name: string, the name of the metric to log. + value: number, the value of the metric. The value will not be logged if it + is not a number type. + unit: string, the unit of the metric, E.g "image per second". + global_step: int, the global_step when the metric is logged. + extras: map of string:string, the extra information about the metric. + """ + metric = _process_metric_to_json(name, value, unit, global_step, extras) + if metric: + try: + json.dump(metric, self._metric_file_handler) + self._metric_file_handler.write("\n") + self._metric_file_handler.flush() + except (TypeError, ValueError) as e: + tf.compat.v1.logging.warning( + "Failed to dump metric to log file: name %s, value %s, error %s", + name, value, e) + + def log_run_info(self, model_name, dataset_name, run_params, test_id=None): + """Collect most of the TF runtime information for the local env. 
+ + The schema of the run info follows official/benchmark/datastore/schema. + + Args: + model_name: string, the name of the model. + dataset_name: string, the name of dataset for training and evaluation. + run_params: dict, the dictionary of parameters for the run, it could + include hyperparameters or other params that are important for the run. + test_id: string, the unique name of the test run by the combination of key + parameters, eg batch size, num of GPU. It is hardware independent. + """ + run_info = _gather_run_info(model_name, dataset_name, run_params, test_id) + + with tf.io.gfile.GFile(os.path.join( + self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f: + try: + json.dump(run_info, f) + f.write("\n") + except (TypeError, ValueError) as e: + tf.compat.v1.logging.warning( + "Failed to dump benchmark run info to log file: %s", e) + + def on_finish(self, status): + self._metric_file_handler.flush() + self._metric_file_handler.close() + + +class BenchmarkBigQueryLogger(BaseBenchmarkLogger): + """Class to log the benchmark information to BigQuery data store.""" + + def __init__(self, + bigquery_uploader, + bigquery_data_set, + bigquery_run_table, + bigquery_run_status_table, + bigquery_metric_table, + run_id): + super(BenchmarkBigQueryLogger, self).__init__() + self._bigquery_uploader = bigquery_uploader + self._bigquery_data_set = bigquery_data_set + self._bigquery_run_table = bigquery_run_table + self._bigquery_run_status_table = bigquery_run_status_table + self._bigquery_metric_table = bigquery_metric_table + self._run_id = run_id + + def log_metric(self, name, value, unit=None, global_step=None, extras=None): + """Log the benchmark metric information to bigquery. + + Args: + name: string, the name of the metric to log. + value: number, the value of the metric. The value will not be logged if it + is not a number type. + unit: string, the unit of the metric, E.g "image per second". + global_step: int, the global_step when the metric is logged. 
+ extras: map of string:string, the extra information about the metric. + """ + metric = _process_metric_to_json(name, value, unit, global_step, extras) + if metric: + # Starting new thread for bigquery upload in case it might take long time + # and impact the benchmark and performance measurement. Starting a new + # thread might have potential performance impact for model that run on + # CPU. + thread.start_new_thread( + self._bigquery_uploader.upload_benchmark_metric_json, + (self._bigquery_data_set, + self._bigquery_metric_table, + self._run_id, + [metric])) + + def log_run_info(self, model_name, dataset_name, run_params, test_id=None): + """Collect most of the TF runtime information for the local env. + + The schema of the run info follows official/benchmark/datastore/schema. + + Args: + model_name: string, the name of the model. + dataset_name: string, the name of dataset for training and evaluation. + run_params: dict, the dictionary of parameters for the run, it could + include hyperparameters or other params that are important for the run. + test_id: string, the unique name of the test run by the combination of key + parameters, eg batch size, num of GPU. It is hardware independent. + """ + run_info = _gather_run_info(model_name, dataset_name, run_params, test_id) + # Starting new thread for bigquery upload in case it might take long time + # and impact the benchmark and performance measurement. Starting a new + # thread might have potential performance impact for model that run on CPU. 
+ thread.start_new_thread( + self._bigquery_uploader.upload_benchmark_run_json, + (self._bigquery_data_set, + self._bigquery_run_table, + self._run_id, + run_info)) + thread.start_new_thread( + self._bigquery_uploader.insert_run_status, + (self._bigquery_data_set, + self._bigquery_run_status_table, + self._run_id, + RUN_STATUS_RUNNING)) + + def on_finish(self, status): + self._bigquery_uploader.update_run_status( + self._bigquery_data_set, + self._bigquery_run_status_table, + self._run_id, + status) + + +def _gather_run_info(model_name, dataset_name, run_params, test_id): + """Collect the benchmark run information for the local environment.""" + run_info = { + "model_name": model_name, + "dataset": {"name": dataset_name}, + "machine_config": {}, + "test_id": test_id, + "run_date": datetime.datetime.utcnow().strftime( + _DATE_TIME_FORMAT_PATTERN)} + _collect_tensorflow_info(run_info) + _collect_tensorflow_environment_variables(run_info) + _collect_run_params(run_info, run_params) + _collect_cpu_info(run_info) + _collect_memory_info(run_info) + _collect_test_environment(run_info) + return run_info + + +def _process_metric_to_json( + name, value, unit=None, global_step=None, extras=None): + """Validate the metric data and generate JSON for insert.""" + if not isinstance(value, numbers.Number): + tf.compat.v1.logging.warning( + "Metric value to log should be a number. 
Got %s", type(value)) + return None + + extras = _convert_to_json_dict(extras) + return { + "name": name, + "value": float(value), + "unit": unit, + "global_step": global_step, + "timestamp": datetime.datetime.utcnow().strftime( + _DATE_TIME_FORMAT_PATTERN), + "extras": extras} + + +def _collect_tensorflow_info(run_info): + run_info["tensorflow_version"] = { + "version": tf.version.VERSION, "git_hash": tf.version.GIT_VERSION} + + +def _collect_run_params(run_info, run_params): + """Log the parameter information for the benchmark run.""" + def process_param(name, value): + type_check = { + str: {"name": name, "string_value": value}, + int: {"name": name, "long_value": value}, + bool: {"name": name, "bool_value": str(value)}, + float: {"name": name, "float_value": value}, + } + return type_check.get(type(value), + {"name": name, "string_value": str(value)}) + if run_params: + run_info["run_parameters"] = [ + process_param(k, v) for k, v in sorted(run_params.items())] + + +def _collect_tensorflow_environment_variables(run_info): + run_info["tensorflow_environment_variables"] = [ + {"name": k, "value": v} + for k, v in sorted(os.environ.items()) if k.startswith("TF_")] + + +# The following code is mirrored from tensorflow/tools/test/system_info_lib +# which is not exposed for import. +def _collect_cpu_info(run_info): + """Collect the CPU information for the local environment.""" + cpu_info = {} + + cpu_info["num_cores"] = multiprocessing.cpu_count() + + try: + # Note: cpuinfo is not installed in the TensorFlow OSS tree. + # It is installable via pip. + import cpuinfo # pylint: disable=g-import-not-at-top + + info = cpuinfo.get_cpu_info() + cpu_info["cpu_info"] = info["brand"] + cpu_info["mhz_per_cpu"] = info["hz_advertised_raw"][0] / 1.0e6 + + run_info["machine_config"]["cpu_info"] = cpu_info + except ImportError: + tf.compat.v1.logging.warn( + "'cpuinfo' not imported. 
CPU info will not be logged.") + + +def _collect_memory_info(run_info): + try: + # Note: psutil is not installed in the TensorFlow OSS tree. + # It is installable via pip. + import psutil # pylint: disable=g-import-not-at-top + vmem = psutil.virtual_memory() + run_info["machine_config"]["memory_total"] = vmem.total + run_info["machine_config"]["memory_available"] = vmem.available + except ImportError: + tf.compat.v1.logging.warn( + "'psutil' not imported. Memory info will not be logged.") + + +def _collect_test_environment(run_info): + """Detect the local environment, eg GCE, AWS or DGX, etc.""" + if cloud_lib.on_gcp(): + run_info["test_environment"] = GCP_TEST_ENV + # TODO(scottzhu): Add more testing env detection for other platform + + +def _parse_gpu_model(physical_device_desc): + # Assume all the GPU connected are same model + for kv in physical_device_desc.split(","): + k, _, v = kv.partition(":") + if k.strip() == "name": + return v.strip() + return None + + +def _convert_to_json_dict(input_dict): + if input_dict: + return [{"name": k, "value": v} for k, v in sorted(input_dict.items())] + else: + return [] diff --git a/examples/benchmark/utils/logs/metric_hook.py b/examples/benchmark/utils/logs/metric_hook.py new file mode 100644 index 0000000..f408e3e --- /dev/null +++ b/examples/benchmark/utils/logs/metric_hook.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Session hook for logging benchmark metric.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf # pylint: disable=g-bad-import-order + + +class LoggingMetricHook(tf.estimator.LoggingTensorHook): + """Hook to log benchmark metric information. + + This hook is very similar as tf.train.LoggingTensorHook, which logs given + tensors every N local steps, every N seconds, or at the end. The metric + information will be logged to given log_dir or via metric_logger in JSON + format, which can be consumed by data analysis pipeline later. + + Note that if `at_end` is True, `tensors` should not include any tensor + whose evaluation produces a side effect such as consuming additional inputs. + """ + + def __init__(self, tensors, metric_logger=None, + every_n_iter=None, every_n_secs=None, at_end=False): + """Initializer for LoggingMetricHook. + + Args: + tensors: `dict` that maps string-valued tags to tensors/tensor names, + or `iterable` of tensors/tensor names. + metric_logger: instance of `BenchmarkLogger`, the benchmark logger that + hook should use to write the log. + every_n_iter: `int`, print the values of `tensors` once every N local + steps taken on the current worker. + every_n_secs: `int` or `float`, print the values of `tensors` once every N + seconds. Exactly one of `every_n_iter` and `every_n_secs` should be + provided. + at_end: `bool` specifying whether to print the values of `tensors` at the + end of the run. + + Raises: + ValueError: + 1. `every_n_iter` is non-positive, or + 2. Exactly one of every_n_iter and every_n_secs should be provided. + 3. Exactly one of log_dir and metric_logger should be provided. 
+ """ + super(LoggingMetricHook, self).__init__( + tensors=tensors, + every_n_iter=every_n_iter, + every_n_secs=every_n_secs, + at_end=at_end) + + if metric_logger is None: + raise ValueError("metric_logger should be provided.") + self._logger = metric_logger + + def begin(self): + super(LoggingMetricHook, self).begin() + self._global_step_tensor = tf.compat.v1.train.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError( + "Global step should be created to use LoggingMetricHook.") + if self._global_step_tensor.name not in self._current_tensors: + self._current_tensors[self._global_step_tensor.name] = ( + self._global_step_tensor) + + def after_run(self, unused_run_context, run_values): + # should_trigger is a internal state that populated at before_run, and it is + # using self_timer to determine whether it should trigger. + if self._should_trigger: + self._log_metric(run_values.results) + + self._iter_count += 1 + + def end(self, session): + if self._log_at_end: + values = session.run(self._current_tensors) + self._log_metric(values) + + def _log_metric(self, tensor_values): + self._timer.update_last_triggered_step(self._iter_count) + global_step = tensor_values[self._global_step_tensor.name] + # self._tag_order is populated during the init of LoggingTensorHook + for tag in self._tag_order: + self._logger.log_metric(tag, tensor_values[tag], global_step=global_step) diff --git a/examples/benchmark/utils/logs/mlperf_helper.py b/examples/benchmark/utils/logs/mlperf_helper.py new file mode 100644 index 0000000..c9c0434 --- /dev/null +++ b/examples/benchmark/utils/logs/mlperf_helper.py @@ -0,0 +1,192 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Wrapper for the mlperf logging utils. + +MLPerf compliance logging is only desired under a limited set of circumstances. +This module is intended to keep users from needing to consider logging (or +install the module) unless they are performing mlperf runs. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import json +import os +import re +import subprocess +import sys +import typing + +import tensorflow as tf + +_MIN_VERSION = (0, 0, 10) +_STACK_OFFSET = 2 + +SUDO = "sudo" if os.geteuid() else "" + +# This indirection is used in docker. 
+DROP_CACHE_LOC = os.getenv("DROP_CACHE_LOC", "/proc/sys/vm/drop_caches") + +_NCF_PREFIX = "NCF_RAW_" + +# TODO(robieta): move line parsing to mlperf util +_PREFIX = r"(?:{})?:::MLPv([0-9]+).([0-9]+).([0-9]+)".format(_NCF_PREFIX) +_BENCHMARK = r"([a-zA-Z0-9_]+)" +_TIMESTAMP = r"([0-9]+\.[0-9]+)" +_CALLSITE = r"\((.+):([0-9]+)\)" +_TAG = r"([a-zA-Z0-9_]+)" +_VALUE = r"(.*)" + +ParsedLine = namedtuple("ParsedLine", ["version", "benchmark", "timestamp", + "callsite", "tag", "value"]) + +LINE_PATTERN = re.compile( + "^{prefix} {benchmark} {timestamp} {callsite} {tag}(: |$){value}?$".format( + prefix=_PREFIX, benchmark=_BENCHMARK, timestamp=_TIMESTAMP, + callsite=_CALLSITE, tag=_TAG, value=_VALUE)) + + +def parse_line(line): # type: (str) -> typing.Optional[ParsedLine] + match = LINE_PATTERN.match(line.strip()) + if not match: + return + + major, minor, micro, benchmark, timestamp = match.groups()[:5] + call_file, call_line, tag, _, value = match.groups()[5:] + + return ParsedLine(version=(int(major), int(minor), int(micro)), + benchmark=benchmark, timestamp=timestamp, + callsite=(call_file, call_line), tag=tag, value=value) + + +def unparse_line(parsed_line): # type: (ParsedLine) -> str + version_str = "{}.{}.{}".format(*parsed_line.version) + callsite_str = "({}:{})".format(*parsed_line.callsite) + value_str = ": {}".format(parsed_line.value) if parsed_line.value else "" + return ":::MLPv{} {} {} {} {} {}".format( + version_str, parsed_line.benchmark, parsed_line.timestamp, callsite_str, + parsed_line.tag, value_str) + + +def get_mlperf_log(): + """Shielded import of mlperf_log module.""" + try: + import mlperf_compliance + + def test_mlperf_log_pip_version(): + """Check that mlperf_compliance is up to date.""" + import pkg_resources + version = pkg_resources.get_distribution("mlperf_compliance") + version = tuple(int(i) for i in version.version.split(".")) + if version < _MIN_VERSION: + tf.compat.v1.logging.warning( + "mlperf_compliance is version {}, must be >= 
{}".format( + ".".join([str(i) for i in version]), + ".".join([str(i) for i in _MIN_VERSION]))) + raise ImportError + return mlperf_compliance.mlperf_log + + mlperf_log = test_mlperf_log_pip_version() + + except ImportError: + mlperf_log = None + + return mlperf_log + + +class Logger(object): + """MLPerf logger indirection class. + + This logger only logs for MLPerf runs, and prevents various errors associated + with not having the mlperf_compliance package installed. + """ + class Tags(object): + def __init__(self, mlperf_log): + self._enabled = False + self._mlperf_log = mlperf_log + + def __getattr__(self, item): + if self._mlperf_log is None or not self._enabled: + return + return getattr(self._mlperf_log, item) + + def __init__(self): + self._enabled = False + self._mlperf_log = get_mlperf_log() + self.tags = self.Tags(self._mlperf_log) + + def __call__(self, enable=False): + if enable and self._mlperf_log is None: + raise ImportError("MLPerf logging was requested, but mlperf_compliance " + "module could not be loaded.") + + self._enabled = enable + self.tags._enabled = enable + return self + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + self._enabled = False + self.tags._enabled = False + + @property + def log_file(self): + if self._mlperf_log is None: + return + return self._mlperf_log.LOG_FILE + + @property + def enabled(self): + return self._enabled + + def ncf_print(self, key, value=None, stack_offset=_STACK_OFFSET, + deferred=False, extra_print=False, prefix=_NCF_PREFIX): + if self._mlperf_log is None or not self.enabled: + return + self._mlperf_log.ncf_print(key=key, value=value, stack_offset=stack_offset, + deferred=deferred, extra_print=extra_print, + prefix=prefix) + + def set_ncf_root(self, path): + if self._mlperf_log is None: + return + self._mlperf_log.ROOT_DIR_NCF = path + + +LOGGER = Logger() +ncf_print, set_ncf_root = LOGGER.ncf_print, LOGGER.set_ncf_root +TAGS = LOGGER.tags + + +def clear_system_caches(): 
+ if not LOGGER.enabled: + return + ret_code = subprocess.call( + ["sync && echo 3 | {} tee {}".format(SUDO, DROP_CACHE_LOC)], + shell=True) + + if ret_code: + raise ValueError("Failed to clear caches") + + +if __name__ == "__main__": + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + with LOGGER(True): + ncf_print(key=TAGS.RUN_START) diff --git a/examples/benchmark/utils/optimization.py b/examples/benchmark/utils/optimization.py deleted file mode 100644 index 275aac3..0000000 --- a/examples/benchmark/utils/optimization.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Functions and classes related to optimization (weight updates).""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import re - -import tensorflow as tf - - -class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): - """Applys a warmup schedule on a given learning rate decay schedule.""" - - def __init__( - self, - initial_learning_rate, - decay_schedule_fn, - warmup_steps, - power=1.0, - name=None): - super(WarmUp, self).__init__() - self.initial_learning_rate = initial_learning_rate - self.warmup_steps = warmup_steps - self.power = power - self.decay_schedule_fn = decay_schedule_fn - self.name = name - - def __call__(self, step): - with tf.name_scope(self.name or 'WarmUp') as name: - # Implements polynomial warmup. i.e., if global_step < warmup_steps, the - # learning rate will be `global_step/num_warmup_steps * init_lr`. - global_step_float = tf.cast(step, tf.float32) - warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) - warmup_percent_done = global_step_float / warmup_steps_float - warmup_learning_rate = ( - self.initial_learning_rate * - tf.math.pow(warmup_percent_done, self.power)) - return tf.cond(global_step_float < warmup_steps_float, - lambda: warmup_learning_rate, - lambda: self.decay_schedule_fn(step), - name=name) - - def get_config(self): - return { - 'initial_learning_rate': self.initial_learning_rate, - 'decay_schedule_fn': self.decay_schedule_fn, - 'warmup_steps': self.warmup_steps, - 'power': self.power, - 'name': self.name - } - - -def create_optimizer(init_lr, num_train_steps, num_warmup_steps): - """Creates an optimizer with learning rate schedule.""" - # Implements linear decay of the learning rate. 
- learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( - initial_learning_rate=init_lr, - decay_steps=num_train_steps, - end_learning_rate=0.0) - if num_warmup_steps: - learning_rate_fn = WarmUp(initial_learning_rate=init_lr, - decay_schedule_fn=learning_rate_fn, - warmup_steps=num_warmup_steps) - optimizer = AdamWeightDecay( - learning_rate=learning_rate_fn, - weight_decay_rate=0.01, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-6, - exclude_from_weight_decay=['layer_norm', 'bias']) - return optimizer - - -class AdamWeightDecay(tf.keras.optimizers.Adam): - """Adam enables L2 weight decay and clip_by_global_norm on gradients. - - Just adding the square of the weights to the loss function is *not* the - correct way of using L2 regularization/weight decay with Adam, since that will - interact with the m and v parameters in strange ways. - - Instead we want ot decay the weights in a manner that doesn't interact with - the m/v parameters. This is equivalent to adding the square of the weights to - the loss with plain (non-momentum) SGD. 
- """ - - def __init__(self, - learning_rate=0.001, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-7, - amsgrad=False, - weight_decay_rate=0.0, - include_in_weight_decay=None, - exclude_from_weight_decay=None, - name='AdamWeightDecay', - **kwargs): - super(AdamWeightDecay, self).__init__( - learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs) - self.weight_decay_rate = weight_decay_rate - self._include_in_weight_decay = include_in_weight_decay - self._exclude_from_weight_decay = exclude_from_weight_decay - - @classmethod - def from_config(cls, config): - """Creates an optimizer from its config with WarmUp custom object.""" - custom_objects = {'WarmUp': WarmUp} - return super(AdamWeightDecay, cls).from_config( - config, custom_objects=custom_objects) - - def _prepare_local(self, var_device, var_dtype, apply_state): - super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, - apply_state) - apply_state['weight_decay_rate'] = tf.constant( - self.weight_decay_rate, name='adam_weight_decay_rate') - - def _decay_weights_op(self, var, learning_rate, apply_state): - do_decay = self._do_use_weight_decay(var.name) - if do_decay: - return var.assign_sub( - learning_rate * var * - apply_state['weight_decay_rate'], - use_locking=self._use_locking) - return tf.no_op() - - def apply_gradients(self, grads_and_vars, name=None): - grads, tvars = list(zip(*grads_and_vars)) - (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) - return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars)) - - def _get_lr(self, var_device, var_dtype, apply_state): - """Retrieves the learning rate with the given state.""" - if apply_state is None: - return self._decayed_lr_t[var_dtype], {} - - apply_state = apply_state or {} - coefficients = apply_state.get((var_device, var_dtype)) - if coefficients is None: - coefficients = self._fallback_apply_state(var_device, var_dtype) - apply_state[(var_device, var_dtype)] = coefficients - - return coefficients['lr_t'], 
dict(apply_state=apply_state) - - def _resource_apply_dense(self, grad, var, apply_state=None): - lr_t, kwargs = self._get_lr( - var.device, var.dtype.base_dtype, apply_state) - decay = self._decay_weights_op(var, lr_t, apply_state) - with tf.control_dependencies([decay]): - return super(AdamWeightDecay, self)._resource_apply_dense( - grad, var, **kwargs) - - def _resource_apply_sparse(self, grad, var, indices, apply_state=None): - lr_t, kwargs = self._get_lr( - var.device, var.dtype.base_dtype, apply_state) - decay = self._decay_weights_op(var, lr_t, apply_state) - with tf.control_dependencies([decay]): - return super(AdamWeightDecay, self)._resource_apply_sparse( - grad, var, indices, **kwargs) - - def get_config(self): - config = super(AdamWeightDecay, self).get_config() - config.update({ - 'weight_decay_rate': self.weight_decay_rate, - }) - return config - - def _do_use_weight_decay(self, param_name): - """Whether to use L2 weight decay for `param_name`.""" - if self.weight_decay_rate == 0: - return False - - if self._include_in_weight_decay: - for r in self._include_in_weight_decay: - if re.search(r, param_name) is not None: - return True - - if self._exclude_from_weight_decay: - for r in self._exclude_from_weight_decay: - if re.search(r, param_name) is not None: - return False - return True