[lazy] support from_pretrained (hpcaitech#4801)
* [lazy] patch from pretrained

* [lazy] fix from pretrained and add tests

* [devops] update ci
ver217 authored Sep 26, 2023
1 parent 64a08b2 commit 4965c0d
Showing 11 changed files with 397 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build_on_pr.yml
@@ -141,7 +141,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 60
defaults:
run:
@@ -214,6 +214,7 @@ jobs:
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
TESTMON_CORE_PKGS: /__w/ColossalAI/ColossalAI/requirements/requirements.txt,/__w/ColossalAI/ColossalAI/requirements/requirements-test.txt
LLAMA_PATH: /data/scratch/llama-tiny

- name: Store Testmon Cache
run: |
3 changes: 2 additions & 1 deletion .github/workflows/build_on_schedule.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, 8-gpu]
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 40
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
@@ -64,6 +64,7 @@ jobs:
env:
DATA: /data/scratch/cifar-10
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny

- name: Notify Lark
id: message-preparation
3 changes: 2 additions & 1 deletion .github/workflows/compatiblity_test_on_dispatch.yml
@@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -92,3 +92,4 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
3 changes: 2 additions & 1 deletion .github/workflows/compatiblity_test_on_pr.yml
@@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
@@ -87,3 +87,4 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
3 changes: 2 additions & 1 deletion .github/workflows/compatiblity_test_on_schedule.yml
@@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
timeout-minutes: 120
steps:
- name: Install dependencies
@@ -85,6 +85,7 @@ jobs:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny

- name: Notify Lark
id: message-preparation
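Each of these workflow changes does the same two things: it bind-mounts the tiny LLaMA checkpoint at /data/scratch/llama-tiny into the CI container and exposes its location to the test step through the LLAMA_PATH environment variable. A minimal sketch of how a test might consume that variable; the test name and body are illustrative, not part of this commit:

import os

import pytest

# Location of the tiny LLaMA checkpoint mounted into the CI container;
# typically unset on a developer machine, so the test is skipped there.
LLAMA_PATH = os.environ.get("LLAMA_PATH")


@pytest.mark.skipif(LLAMA_PATH is None, reason="requires LLAMA_PATH to point at a checkpoint")
def test_llama_checkpoint_is_readable():
    from transformers import AutoConfig

    # The real tests added in this commit go further and load the model lazily;
    # here we only check that the mounted checkpoint directory is usable.
    config = AutoConfig.from_pretrained(LLAMA_PATH)
    assert config.model_type == "llama"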
8 changes: 8 additions & 0 deletions colossalai/booster/booster.py
@@ -8,6 +8,7 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper

@@ -131,6 +132,7 @@ def boost(
"""
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(FrankLeeeee): consider multi-dataloader case
pretrained_path = pretrained_utils.get_pretrained_path(model)
# transform model for mixed precision
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
@@ -146,6 +148,12 @@ def boost(
# when mixed_precision is specified and the plugin is not given or does not control the precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)

if pretrained_path:
self.load_model(model, pretrained_path)
# clear pretrained path attr
orig_model = model.unwrap() if isinstance(model, ModelWrapper) else model
pretrained_utils.set_pretrained_path(orig_model, None)

return model, optimizer, criterion, dataloader, lr_scheduler

def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
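On the booster side, boost() now asks colossalai.interface.pretrained for a path recorded on the incoming model, lets the plugin configure the model as usual, and only then calls self.load_model() with that path before clearing the attribute on the unwrapped model. A rough sketch of the resulting usage; the plugin, optimizer, and local checkpoint path are illustrative choices, not prescribed by this commit:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from transformers import LlamaForCausalLM

# Launched via torchrun so the distributed environment variables are set.
colossalai.launch_from_torch(config={})

# Inside LazyInitContext, from_pretrained is patched: it builds the model with
# lazy (un-materialized) parameters and records the checkpoint path on it
# instead of loading the weights eagerly.
with LazyInitContext():
    model = LlamaForCausalLM.from_pretrained("/data/scratch/llama-tiny")

optimizer = HybridAdam(model.parameters(), lr=1e-3)
booster = Booster(plugin=GeminiPlugin())

# boost() materializes the model through the plugin, then loads the recorded
# checkpoint via load_model() and clears the recorded path.
model, optimizer, *_ = booster.boost(model, optimizer)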
16 changes: 16 additions & 0 deletions colossalai/interface/pretrained.py
@@ -0,0 +1,16 @@
from typing import Optional

from torch.nn import Module

__all__ = [
"get_pretrained_path",
"set_pretrained_path",
]


def get_pretrained_path(model: Module) -> Optional[str]:
return getattr(model, "_pretrained", None)


def set_pretrained_path(model: Module, path: str) -> None:
setattr(model, "_pretrained", path)
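These helpers are deliberately thin: they stash the checkpoint path as a private _pretrained attribute on the module and read it back, so whichever component later receives the model (here, the booster) can discover where its weights live. For example:

import torch.nn as nn

from colossalai.interface.pretrained import get_pretrained_path, set_pretrained_path

model = nn.Linear(4, 4)
assert get_pretrained_path(model) is None  # nothing recorded yet

set_pretrained_path(model, "/data/scratch/llama-tiny")
assert get_pretrained_path(model) == "/data/scratch/llama-tiny"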
3 changes: 3 additions & 0 deletions colossalai/lazy/lazy_init.py
@@ -11,6 +11,7 @@
from colossalai.logging import get_dist_logger

from .construction import ConstructorManager
from .pretrained import PretrainedManager

import colossalai._analyzer._subclasses._meta_registration # noqa

@@ -595,11 +596,13 @@ def wrapper(*args, **kwargs):
)

ConstructorManager.apply(overrides)
PretrainedManager.inject()

def __exit__(self, exc_type, exc_val, exc_tb):
self.tensor_cls.default_device = self.old_default_device
LazyInitContext._replaced = False
ConstructorManager.clear()
PretrainedManager.recover()

@staticmethod
def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
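PretrainedManager itself lives in colossalai/lazy/pretrained.py, which is not expanded in this view; LazyInitContext simply injects it on entry and restores the original behaviour on exit. Conceptually, the patch has to make from_pretrained stop loading weights and instead record the checkpoint path for the booster. A much-simplified, hypothetical illustration of that pattern (the class and method bodies are guesses at the shape, not the actual implementation):

import transformers

from colossalai.interface.pretrained import set_pretrained_path


class ToyPretrainedPatch:
    """Toy stand-in for PretrainedManager: swap from_pretrained for a variant
    that builds the model from its config and records the checkpoint path."""

    _orig = None

    @classmethod
    def inject(cls):
        cls._orig = transformers.PreTrainedModel.from_pretrained.__func__

        def from_pretrained(model_cls, pretrained_model_name_or_path, *args, **kwargs):
            config = transformers.AutoConfig.from_pretrained(pretrained_model_name_or_path)
            model = model_cls(config)  # parameters stay lazy inside LazyInitContext
            set_pretrained_path(model, pretrained_model_name_or_path)
            return model

        transformers.PreTrainedModel.from_pretrained = classmethod(from_pretrained)

    @classmethod
    def recover(cls):
        transformers.PreTrainedModel.from_pretrained = classmethod(cls._orig)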
(The remaining changed files in this commit were not loaded in this view.)