Skip to content

Commit

Permalink
Fix memory leak from _hp_mapping (#5643)
Browse files Browse the repository at this point in the history
See #5496 
I don't really know if this is a good solution
  • Loading branch information
chiragjn authored Jun 25, 2024
1 parent b3767d0 commit 821af15
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
19 changes: 17 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
setattr(sys.modules[__name__], 'fragment_address', fragment_address)


def print_rank_0(message, debug=False, force=False):
if dist.get_rank() == 0 and (debug or force):
print(message)


class BF16_Optimizer(ZeROOptimizer):

def __init__(self,
Expand Down Expand Up @@ -92,7 +97,16 @@ def __init__(self,
if self.using_real_optimizer:
self._setup_for_real_optimizer()

see_memory_usage('end bf16_optimizer', force=True)
see_memory_usage('end bf16_ optimizer', force=True)

def destroy(self):
for i, _ in enumerate(self.optimizer.param_groups):
for p in self.bf16_groups[i]:
if getattr(p, '_hp_mapping', None):
p._hp_mapping = None
for hook in self._grad_acc_hooks:
hook.remove()
print_rank_0("Removed grad acc hooks")

def _configure_moe_settings(self):
assert any(
Expand Down Expand Up @@ -187,6 +201,7 @@ def _setup_for_real_optimizer(self):
self.initialize_optimizer_states()
see_memory_usage('end initialize_optimizer', force=True)

self._grad_acc_hooks = []
if self.immediate_grad_update:
self.create_grad_acc_hooks()

Expand Down Expand Up @@ -541,7 +556,7 @@ def wrapper(param, i, j):
def accumulate_hp_grads_and_remove_lp(*notneeded):
self.accumulate_hp_grads_and_remove_lp(param, i, j)

grad_acc.register_hook(accumulate_hp_grads_and_remove_lp)
self._grad_acc_hooks.append(grad_acc.register_hook(accumulate_hp_grads_and_remove_lp))
self.grad_accs.append(grad_acc)

wrapper(param, i, j)
Expand Down
4 changes: 4 additions & 0 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,10 @@ def __init__(self,
self._param_slice_mappings = self._create_param_mapping()

def destroy(self):
for i, _ in enumerate(self.optimizer.param_groups):
for p in self.bit16_groups[i]:
if getattr(p, '_hp_mapping', None):
p._hp_mapping = None
for hook in self._grad_acc_hooks:
hook.remove()
self.print_rank_0("Removed grad acc hooks")
Expand Down

0 comments on commit 821af15

Please sign in to comment.