🐛 fix multi-gpu issues (#156)
Introduces a new class DeviceAllocatorContextManagerV2, which no longer requires experimental APIs and uses dynamic memory growth. It is fully backwards compatible: replacing the extended class DeviceAllocatorContextManager with DeviceAllocatorContextManagerV2 completes the migration. The gpu_memory_limit function is no longer called and can be removed.

This closes #75
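
For illustration, a minimal before/after sketch of the migration (the subclass name MyContext is a hypothetical placeholder; its method bodies mirror the example in examples/multi_device.py changed below):

from typing import Dict
import uncertainty_wizard as uwiz

# Before (deprecated; only works on tensorflow older than 2.10):
class MyContext(uwiz.models.ensemble_utils.DeviceAllocatorContextManager):
    @classmethod
    def file_path(cls) -> str:
        return "temp-ensemble.txt"

    @classmethod
    def virtual_devices_per_gpu(cls) -> Dict[int, int]:
        return {0: 2, 1: 3}

    @classmethod
    def gpu_memory_limit(cls) -> int:
        return 1500

# After: extend DeviceAllocatorContextManagerV2 instead and drop gpu_memory_limit:
class MyContext(uwiz.models.ensemble_utils.DeviceAllocatorContextManagerV2):
    @classmethod
    def file_path(cls) -> str:
        return "temp-ensemble.txt"

    @classmethod
    def virtual_devices_per_gpu(cls) -> Dict[int, int]:
        return {0: 2, 1: 3}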
MiWeiss authored Feb 16, 2023
1 parent 1731d32 commit 738ed10
Showing 3 changed files with 164 additions and 79 deletions.
71 changes: 47 additions & 24 deletions examples/multi_device.py
@@ -9,8 +9,7 @@
import uncertainty_wizard as uwiz


class MultiGpuContext(uwiz.models.ensemble_utils.DeviceAllocatorContextManager):

class MultiGpuContext(uwiz.models.ensemble_utils.DeviceAllocatorContextManagerV2):
@classmethod
def file_path(cls) -> str:
return "temp-ensemble.txt"
@@ -27,43 +26,63 @@ def virtual_devices_per_gpu(cls) -> Dict[int, int]:
# Here, we configure a setting with two gpus
# On gpu 0, two atomic models will be executed at the same time
# On gpu 1, three atomic models will be executed at the same time
return {
0: 2,
1: 3
}

@classmethod
def gpu_memory_limit(cls) -> int:
return 1500
return {0: 2, 1: 3}


def train_model(model_id):
import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same',
input_shape=(32, 32, 3)))
model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(
tf.keras.layers.Conv2D(
32,
kernel_size=(3, 3),
activation="relu",
padding="same",
input_shape=(32, 32, 3),
)
)
model.add(
tf.keras.layers.Conv2D(
32, kernel_size=(3, 3), activation="relu", padding="same"
)
)
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(
tf.keras.layers.Conv2D(
64, kernel_size=(3, 3), activation="relu", padding="same"
)
)
model.add(
tf.keras.layers.Conv2D(
64, kernel_size=(3, 3), activation="relu", padding="same"
)
)
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(tf.keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same'))
model.add(
tf.keras.layers.Conv2D(
128, kernel_size=(3, 3), activation="relu", padding="same"
)
)
model.add(
tf.keras.layers.Conv2D(
128, kernel_size=(3, 3), activation="relu", padding="same"
)
)
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(128, activation="relu"))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.add(tf.keras.layers.Dense(10, activation="softmax"))

opt = tf.keras.optimizers.SGD(lr=0.001, momentum=0.9)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["accuracy"])

(x_train, y_train), (_, _) = tf.keras.datasets.cifar10.load_data()
x_train = x_train / 255.
x_train = x_train / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)

# For the sake of this example, let's use just one epoch.
@@ -73,16 +92,20 @@ def train_model(model_id):
return model, "history_not_returned"


if __name__ == '__main__':
if __name__ == "__main__":
# Make sure the training data is cached on the fs before the multiprocessing starts
# Otherwise, all processes will simultaneously attempt to download and cache data,
# which will fail as they break each others caches
tensorflow.keras.datasets.cifar10.load_data()

# set this path to where you want to save the ensemble
temp_dir = "/tmp/ensemble"
ensemble = uwiz.models.LazyEnsemble(num_models=20, model_save_path=temp_dir, delete_existing=True,
default_num_processes=5)
ensemble = uwiz.models.LazyEnsemble(
num_models=20,
model_save_path=temp_dir,
delete_existing=True,
default_num_processes=5,
)
ensemble.create(train_model, context=MultiGpuContext)

print("Ensemble was successfully trained")
2 changes: 2 additions & 0 deletions uncertainty_wizard/models/ensemble_utils/__init__.py
@@ -3,13 +3,15 @@
"DynamicGpuGrowthContextManager",
"NoneContextManager",
"DeviceAllocatorContextManager",
"DeviceAllocatorContextManagerV2",
"CpuOnlyContextManager",
"SaveConfig",
]

from ._lazy_contexts import (
CpuOnlyContextManager,
DeviceAllocatorContextManager,
DeviceAllocatorContextManagerV2,
DynamicGpuGrowthContextManager,
EnsembleContextManager,
NoneContextManager,
170 changes: 115 additions & 55 deletions uncertainty_wizard/models/ensemble_utils/_lazy_contexts.py
@@ -3,7 +3,9 @@
import os
import pickle
import time
from typing import Dict
import warnings
from abc import ABC
from typing import Dict, Optional

import tensorflow as tf

@@ -219,23 +221,11 @@ def enable_dynamic_gpu_growth(cls):
global device_id


class DeviceAllocatorContextManager(EnsembleContextManager, abc.ABC):
"""
This context manager configures tensorflow such a user-defined amount of processes for every available gpu
are started. In addition, running a process on the CPU can be enabled.
This is an abstract context manager. To use it, one has to subclass it and override (at least)
the abstract methods.
"""
class _DeviceAllocatorContextManagerAbs(EnsembleContextManager, abc.ABC):
"""Abstract parent class for all managers that allocate processes to specific devices."""

def __init__(self):
super().__init__()
if not current_tf_version_is_older_than("2.10.0"):
raise RuntimeError(
"The DeviceAllocatorContextManager is not compatible with tensorflow 2.10.0 "
"or newer. Please fall back to a single GPU for now (see issue #75),"
"or downgrade to tensorflow 2.9.0."
)
def __init__(self, model_id: int, varargs: dict = None):
super().__init__(model_id, varargs)

# docstr-coverage: inherited
def __enter__(self) -> "DeviceAllocatorContextManager":
@@ -255,7 +245,6 @@ def __enter__(self) -> "DeviceAllocatorContextManager":
# docstr-coverage: inherited
def __exit__(self, type, value, traceback) -> None:
super().__exit__(type, value, traceback)

global number_of_tasks_in_this_process
global device_id
if number_of_tasks_in_this_process == self.max_sequential_tasks_per_process():
@@ -333,18 +322,6 @@ def virtual_devices_per_gpu(cls) -> Dict[int, int]:
:return: A mapping specifying how many processes of this ensemble should run concurrently per gpu.
"""

@classmethod
@abc.abstractmethod
def gpu_memory_limit(cls) -> int:
"""
Override this method to specify the amount of MB which should be used
when creating the virtual device on the GPU. Ignored for CPUs.
*Attention:* This function must be pure: Repeated calls should always return the same value.
:return: The amount of MB which will be reserved on the selected gpu in the created context.
"""

@classmethod
def acquire_lock_timeout(cls) -> int:
"""
@@ -433,31 +410,9 @@ def _pick_device(cls, availablilities) -> int:
print(f"Availabilities: {availablilities}. Picked Device {picked_device}")
return picked_device

@classmethod
def _use_gpu(cls, index: int):
size = cls.gpu_memory_limit()
gpus = tf.config.experimental.list_physical_devices("GPU")

# Check if selected gpu can be found
if gpus is None or len(gpus) <= index:
raise ValueError(
f"Uncertainty Wizards DeviceAllocatorContextManager was configured to use gpu {index} "
f"but no no such gpu was found. "
)

try:
tf.config.set_visible_devices([gpus[index]], "GPU")
tf.config.experimental.set_virtual_device_configuration(
gpus[index],
[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=size)],
)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
except RuntimeError as e:
raise ValueError(
f"Uncertainty Wizard was unable to create a virtual device "
f"on gpu {index} and memory limit {size}MB"
) from e
@abc.abstractmethod
def _use_gpu(self, index: int):
pass

@classmethod
def _acquire_lock(cls) -> int:
@@ -501,3 +456,108 @@ def _acquire_lock(cls) -> int:
def _release_lock(cls, lockfile: int):
os.close(lockfile)
os.remove(cls._lock_file_path())


class DeviceAllocatorContextManager(_DeviceAllocatorContextManagerAbs, ABC):
"""DEPRECATED. Please use DeviceAllocatorContextManagerV2 instead.
This context manager configures tensorflow such that a user-defined number of processes is started
for every available gpu. In addition, running a process on the CPU can be enabled.
This is an abstract context manager. To use it, one has to subclass it and override (at least)
the abstract methods.
"""

def __init__(self, model_id: int, varargs: dict = None):
super().__init__(model_id, varargs)
if not current_tf_version_is_older_than("2.10.0"):
raise RuntimeError(
"The DeviceAllocatorContextManager is not compatible with tensorflow 2.10.0 "
"or newer. Please use DeviceAllocatorContextManagerV2 instead."
)

warnings.warn(
"DeviceAllocatorContextManager is deprecated. "
"Please use DeviceAllocatorContextManagerV2 instead. "
"Migration is easy, just extend DeviceAllocatorContextManagerV2 "
"instead of DeviceAllocatorContextManager. "
"and remove the `gpu_memory_limit` method from your extension. ",
DeprecationWarning,
)

def _use_gpu(self, index: int):
size = self.gpu_memory_limit()
gpus = tf.config.experimental.list_physical_devices("GPU")

# Check if selected gpu can be found
if gpus is None or len(gpus) <= index:
raise ValueError(
f"Uncertainty Wizards DeviceAllocatorContextManager was configured to use gpu {index} "
f"but no no such gpu was found. "
)

try:
tf.config.set_visible_devices([gpus[index]], "GPU")
tf.config.experimental.set_virtual_device_configuration(
gpus[index],
[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=size)],
)
logical_gpus = tf.config.experimental.list_logical_devices("GPU")
print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
except RuntimeError as e:
raise ValueError(
f"Uncertainty Wizard was unable to create a virtual device "
f"on gpu {index} and memory limit {size}MB"
) from e

@classmethod
@abc.abstractmethod
def gpu_memory_limit(cls) -> Optional[int]:
"""
Override this method to specify the amount of MB which should be used
when creating the virtual device on the GPU. Ignored for CPUs.
*Attention:* This function must be pure: Repeated calls should always return the same value.
:return: The amount of MB which will be reserved on the selected gpu in the created context.
"""


class DeviceAllocatorContextManagerV2(_DeviceAllocatorContextManagerAbs, ABC):
"""Distributes processes over multiple GPUs.
You can specify how many processes should be started on each GPU.
To use this context manager, you have to subclass it and override the abstract methods."""

def __init__(self, model_id: int, varargs: dict = None):
super().__init__(model_id, varargs)
self.dynamic_memory_growth_initialized = False
self.tf_device = None

if self.gpu_memory_limit() is not None:
warnings.warn(
"The DeviceAllocatorContextManagerV2 require or support setting a gpu memory limit. "
"Instead, memory is grown dynamically as needed. (but only reduced when the "
"process is terminated)."
"Your implementation of `gpu_memory_limit` will be ignored.",
UserWarning,
)

@classmethod
def gpu_memory_limit(cls) -> Optional[int]:
"""Not needed in DeviceAllocatorContextManagerV2 anymore. Ignored."""
return None

def _use_gpu(self, index: int):
if not self.dynamic_memory_growth_initialized:
DynamicGpuGrowthContextManager.enable_dynamic_gpu_growth()
self.dynamic_memory_growth_initialized = True

self.tf_device = tf.device(f"gpu:{index}")
self.tf_device.__enter__()

def __exit__(self, type, value, traceback) -> None:
super().__exit__(type, value, traceback)
if self.tf_device is not None:
self.tf_device.__exit__()
self.tf_device = None
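
For context, the dynamic memory growth the V2 manager relies on is tensorflow's per-GPU memory-growth option. A minimal sketch of how that option is typically enabled (an illustration only, not necessarily the exact body of DynamicGpuGrowthContextManager.enable_dynamic_gpu_growth):

import tensorflow as tf

# Enable dynamic memory growth on every visible GPU so the process only
# claims GPU memory as it actually needs it; this must be set before
# tensorflow initializes the GPUs.
for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)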
