From 738ed105e65d347bac5cad9406431dc610d7056b Mon Sep 17 00:00:00 2001
From: Michael Weiss
Date: Thu, 16 Feb 2023 11:41:11 +0100
Subject: [PATCH] :bug: fix multi-gpu issues (#156)

Introduces a new class DeviceAllocatorContextManagerV2 which does not
require experimental APIs anymore and uses dynamic memory growth.
It is fully backwards compatible: just changing the class extension from
DeviceAllocatorContextManager to DeviceAllocatorContextManagerV2 does the
full migration. The gpu_memory_limit function is not called anymore and
can be removed.

This closes #75
---
 examples/multi_device.py                     |  71 +++++---
 .../models/ensemble_utils/__init__.py        |   2 +
 .../models/ensemble_utils/_lazy_contexts.py  | 170 ++++++++++++------
 3 files changed, 164 insertions(+), 79 deletions(-)

diff --git a/examples/multi_device.py b/examples/multi_device.py
index 25222aa..74e7e57 100644
--- a/examples/multi_device.py
+++ b/examples/multi_device.py
@@ -9,8 +9,7 @@
 import uncertainty_wizard as uwiz
 
 
-class MultiGpuContext(uwiz.models.ensemble_utils.DeviceAllocatorContextManager):
-
+class MultiGpuContext(uwiz.models.ensemble_utils.DeviceAllocatorContextManagerV2):
     @classmethod
     def file_path(cls) -> str:
         return "temp-ensemble.txt"
@@ -27,43 +26,63 @@ def virtual_devices_per_gpu(cls) -> Dict[int, int]:
         # Here, we configure a setting with two gpus
         # On gpu 0, two atomic models will be executed at the same time
         # On gpu 1, three atomic models will be executed at the same time
-        return {
-            0: 2,
-            1: 3
-        }
-
-    @classmethod
-    def gpu_memory_limit(cls) -> int:
-        return 1500
+        return {0: 2, 1: 3}
 
 
 def train_model(model_id):
     import tensorflow as tf
 
     model = tf.keras.models.Sequential()
-    model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same',
-                                     input_shape=(32, 32, 3)))
-    model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same'))
+    model.add(
+        tf.keras.layers.Conv2D(
+            32,
+            kernel_size=(3, 3),
+            activation="relu",
+            padding="same",
+            input_shape=(32, 32, 3),
+        )
+    )
+    model.add(
+        tf.keras.layers.Conv2D(
+            32, kernel_size=(3, 3), activation="relu", padding="same"
+        )
+    )
     model.add(tf.keras.layers.MaxPooling2D((2, 2)))
     model.add(tf.keras.layers.Dropout(0.2))
-    model.add(tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same'))
-    model.add(tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same'))
+    model.add(
+        tf.keras.layers.Conv2D(
+            64, kernel_size=(3, 3), activation="relu", padding="same"
+        )
+    )
+    model.add(
+        tf.keras.layers.Conv2D(
+            64, kernel_size=(3, 3), activation="relu", padding="same"
+        )
+    )
     model.add(tf.keras.layers.MaxPooling2D((2, 2)))
     model.add(tf.keras.layers.Dropout(0.2))
-    model.add(tf.keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'))
-    model.add(tf.keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'))
+    model.add(
+        tf.keras.layers.Conv2D(
+            128, kernel_size=(3, 3), activation="relu", padding="same"
+        )
+    )
+    model.add(
+        tf.keras.layers.Conv2D(
+            128, kernel_size=(3, 3), activation="relu", padding="same"
+        )
+    )
     model.add(tf.keras.layers.MaxPooling2D((2, 2)))
     model.add(tf.keras.layers.Dropout(0.2))
     model.add(tf.keras.layers.Flatten())
-    model.add(tf.keras.layers.Dense(128, activation='relu'))
+    model.add(tf.keras.layers.Dense(128, activation="relu"))
     model.add(tf.keras.layers.Dropout(0.2))
-    model.add(tf.keras.layers.Dense(10, activation='softmax'))
+    model.add(tf.keras.layers.Dense(10, activation="softmax"))
     opt = tf.keras.optimizers.SGD(lr=0.001, momentum=0.9)
-    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
+    model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])
 
     (x_train, y_train), (_, _) = tf.keras.datasets.cifar10.load_data()
-    x_train = x_train / 255.
+    x_train = x_train / 255.0
     y_train = tf.keras.utils.to_categorical(y_train, 10)
 
     # For the sake of this example, let's use just one epoch.
@@ -73,7 +92,7 @@ def train_model(model_id):
     return model, "history_not_returned"
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # Make sure the training data is cached on the fs before the multiprocessing starts
     # Otherwise, all processes will simultaneously attempt to download and cache data,
    # which will fail as they break each others caches
@@ -81,8 +100,12 @@ def train_model(model_id):
 
     # set this path to where you want to save the ensemble
     temp_dir = "/tmp/ensemble"
-    ensemble = uwiz.models.LazyEnsemble(num_models=20, model_save_path=temp_dir, delete_existing=True,
-                                        default_num_processes=5)
+    ensemble = uwiz.models.LazyEnsemble(
+        num_models=20,
+        model_save_path=temp_dir,
+        delete_existing=True,
+        default_num_processes=5,
+    )
     ensemble.create(train_model, context=MultiGpuContext)
     print("Ensemble was successfully trained")
 
diff --git a/uncertainty_wizard/models/ensemble_utils/__init__.py b/uncertainty_wizard/models/ensemble_utils/__init__.py
index 0bcefe9..67e2417 100644
--- a/uncertainty_wizard/models/ensemble_utils/__init__.py
+++ b/uncertainty_wizard/models/ensemble_utils/__init__.py
@@ -3,6 +3,7 @@
     "DynamicGpuGrowthContextManager",
     "NoneContextManager",
     "DeviceAllocatorContextManager",
+    "DeviceAllocatorContextManagerV2",
     "CpuOnlyContextManager",
     "SaveConfig",
 ]
@@ -10,6 +11,7 @@
 from ._lazy_contexts import (
     CpuOnlyContextManager,
     DeviceAllocatorContextManager,
+    DeviceAllocatorContextManagerV2,
     DynamicGpuGrowthContextManager,
     EnsembleContextManager,
     NoneContextManager,
diff --git a/uncertainty_wizard/models/ensemble_utils/_lazy_contexts.py b/uncertainty_wizard/models/ensemble_utils/_lazy_contexts.py
index a556d17..e9e2d4f 100644
--- a/uncertainty_wizard/models/ensemble_utils/_lazy_contexts.py
+++ b/uncertainty_wizard/models/ensemble_utils/_lazy_contexts.py
@@ -3,7 +3,9 @@
 import os
 import pickle
 import time
-from typing import Dict
+import warnings
+from abc import ABC
+from typing import Dict, Optional
 
 import tensorflow as tf
 
@@ -219,23 +221,11 @@ def enable_dynamic_gpu_growth(cls):
 global device_id
 
 
-class DeviceAllocatorContextManager(EnsembleContextManager, abc.ABC):
-    """
-    This context manager configures tensorflow such a user-defined amount of processes for every available gpu
-    are started. In addition, running a process on the CPU can be enabled.
-
-    This is an abstract context manager. To use it, one has to subclass it and override (at least)
-    the abstract methods.
-    """
+class _DeviceAllocatorContextManagerAbs(EnsembleContextManager, abc.ABC):
+    """Abstract parent class for all managers that allocate processes to specific devices."""
 
-    def __init__(self):
-        super().__init__()
-        if not current_tf_version_is_older_than("2.10.0"):
-            raise RuntimeError(
-                "The DeviceAllocatorContextManager is not compatible with tensorflow 2.10.0 "
-                "or newer. Please fall back to a single GPU for now (see issue #75),"
-                "or downgrade to tensorflow 2.9.0."
-            )
+    def __init__(self, model_id: int, varargs: dict = None):
+        super().__init__(model_id, varargs)
 
     # docstr-coverage: inherited
     def __enter__(self) -> "DeviceAllocatorContextManager":
@@ -255,7 +245,6 @@ def __enter__(self) -> "DeviceAllocatorContextManager":
     # docstr-coverage: inherited
     def __exit__(self, type, value, traceback) -> None:
         super().__exit__(type, value, traceback)
-
         global number_of_tasks_in_this_process
         global device_id
         if number_of_tasks_in_this_process == self.max_sequential_tasks_per_process():
@@ -333,18 +322,6 @@ def virtual_devices_per_gpu(cls) -> Dict[int, int]:
         :return: A mapping specifying how many processes of this ensemble should run concurrently per gpu.
         """
 
-    @classmethod
-    @abc.abstractmethod
-    def gpu_memory_limit(cls) -> int:
-        """
-        Override this method to specify the amount of MB which should be used
-        when creating the virtual device on the GPU. Ignored for CPUs.
-
-        *Attention:* This function must be pure: Repeated calls should always return the same value.
-
-        :return: The amount of MB which will be reserved on the selected gpu in the created context.
-        """
-
     @classmethod
     def acquire_lock_timeout(cls) -> int:
         """
@@ -433,31 +410,9 @@ def _pick_device(cls, availablilities) -> int:
         print(f"Availabilities: {availablilities}. Picked Device {picked_device}")
         return picked_device
 
-    @classmethod
-    def _use_gpu(cls, index: int):
-        size = cls.gpu_memory_limit()
-        gpus = tf.config.experimental.list_physical_devices("GPU")
-
-        # Check if selected gpu can be found
-        if gpus is None or len(gpus) <= index:
-            raise ValueError(
-                f"Uncertainty Wizards DeviceAllocatorContextManager was configured to use gpu {index} "
-                f"but no no such gpu was found. "
-            )
-
-        try:
-            tf.config.set_visible_devices([gpus[index]], "GPU")
-            tf.config.experimental.set_virtual_device_configuration(
-                gpus[index],
-                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=size)],
-            )
-            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
-            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
-        except RuntimeError as e:
-            raise ValueError(
-                f"Uncertainty Wizard was unable to create a virtual device "
-                f"on gpu {index} and memory limit {size}MB"
-            ) from e
+    @abc.abstractmethod
+    def _use_gpu(self, index: int):
+        pass
 
     @classmethod
     def _acquire_lock(cls) -> int:
@@ -501,3 +456,108 @@ def _acquire_lock(cls) -> int:
     def _release_lock(cls, lockfile: int):
         os.close(lockfile)
         os.remove(cls._lock_file_path())
+
+
+class DeviceAllocatorContextManager(_DeviceAllocatorContextManagerAbs, ABC):
+    """DEPRECATED. Please use DeviceAllocatorContextManagerV2 instead.
+
+    This context manager configures tensorflow such that a user-defined number of processes
+    is started for every available gpu. In addition, running a process on the CPU can be enabled.
+
+    This is an abstract context manager. To use it, one has to subclass it and override (at least)
+    the abstract methods.
+    """
+
+    def __init__(self, model_id: int, varargs: dict = None):
+        super().__init__(model_id, varargs)
+        if not current_tf_version_is_older_than("2.10.0"):
+            raise RuntimeError(
+                "The DeviceAllocatorContextManager is not compatible with tensorflow 2.10.0 "
+                "or newer. Please use DeviceAllocatorContextManagerV2 instead."
+            )
+
+        warnings.warn(
+            "DeviceAllocatorContextManager is deprecated. "
+            "Please use DeviceAllocatorContextManagerV2 instead. "
+            "Migration is easy, just extend DeviceAllocatorContextManagerV2 "
+            "instead of DeviceAllocatorContextManager, "
" + "and remove the `gpu_memory_limit` method from your extension. ", + DeprecationWarning, + ) + + def _use_gpu(self, index: int): + size = self.gpu_memory_limit() + gpus = tf.config.experimental.list_physical_devices("GPU") + + # Check if selected gpu can be found + if gpus is None or len(gpus) <= index: + raise ValueError( + f"Uncertainty Wizards DeviceAllocatorContextManager was configured to use gpu {index} " + f"but no no such gpu was found. " + ) + + try: + tf.config.set_visible_devices([gpus[index]], "GPU") + tf.config.experimental.set_virtual_device_configuration( + gpus[index], + [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=size)], + ) + logical_gpus = tf.config.experimental.list_logical_devices("GPU") + print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") + except RuntimeError as e: + raise ValueError( + f"Uncertainty Wizard was unable to create a virtual device " + f"on gpu {index} and memory limit {size}MB" + ) from e + + @classmethod + @abc.abstractmethod + def gpu_memory_limit(cls) -> Optional[int]: + """ + Override this method to specify the amount of MB which should be used + when creating the virtual device on the GPU. Ignored for CPUs. + + *Attention:* This function must be pure: Repeated calls should always return the same value. + + :return: The amount of MB which will be reserved on the selected gpu in the created context. + """ + + +class DeviceAllocatorContextManagerV2(_DeviceAllocatorContextManagerAbs, ABC): + """Distributes processes over multiple GPUs. + + You can specify how many processes should be started on each GPU. + To use this context manager, you have to subclass it and override the abstract methods.""" + + def __init__(self, model_id: int, varargs: dict = None): + super().__init__(model_id, varargs) + self.dynamic_memory_growth_initialized = False + self.tf_device = None + + if self.gpu_memory_limit() is not None: + warnings.warn( + "The DeviceAllocatorContextManagerV2 require or support setting a gpu memory limit. " + "Instead, memory is grown dynamically as needed. (but only reduced when the " + "process is terminated)." + "Your implementation of `gpu_memory_limit` will be ignored.", + UserWarning, + ) + + @classmethod + def gpu_memory_limit(cls) -> Optional[int]: + """Not needed in DeviceAllocatorContextManagerV2 anymore. Ignored.""" + return None + + def _use_gpu(self, index: int): + if not self.dynamic_memory_growth_initialized: + DynamicGpuGrowthContextManager.enable_dynamic_gpu_growth() + self.dynamic_memory_growth_initialized = True + + self.tf_device = tf.device(f"gpu:{index}") + self.tf_device.__enter__() + + def __exit__(self, type, value, traceback) -> None: + super().__exit__(type, value, traceback) + if self.tf_device is not None: + self.tf_device.__exit__() + self.tf_device = None