Skip to content

Commit

Permalink
Support loading models trained with different model_parallel_world_si…
Browse files Browse the repository at this point in the history
…ze. (#16)

* temp save

* quick fix of demo memory issue

* Refactor tensor creation dtype / device control.

This commit makes two changes during model creation:
1. Decouples promote_trainable_params_to_fp32 from model __init__. This
   is to avoid casting to fp32 to save memory in inference-only mode
   (#4).
2. Use a context manager to manage default tensor type change. In the
   previous version, the default tensor type is reset to
   torch.FloatTensor after creating the vision model, which is
   technically incorrect and should be the previous default tensor type
   instead. We implement our own context manager because the official
   context managers seem to be incomplete at this time (PyTorch 2.0.1):
   No dtype manager is provided and set_default_device is ineffective to
   the torch.Tensor calls which are used in fairscale.

* Change CLIP dtype management in llama.py

It is probably safer to keep CLIP at its original precision (e.g., fp16)
regardless of the autocast setting: Some casting (e.g., from fp16 to
bf16) may be lossy and can potentially harm the pre-trained model.

Keep the changes to llama.py only at this moment since a lot of copy-
pasted codes may be refactored in the future (#3).

* Respect args.precision when saving checkpoints.

* Support checkpoint merge

Checkpoint merge is suported in misc/tensor_parallel.py. Merge requires
that the checkpoint_mp_world_size % mp_world_size == 0. Support for
split (i.e., when mp_world_size % checkpoint_mp_world_size == 0) and
redistribute (for general mp_world_size and checkpoint_mp_world_size
values) will be added in the future.

Also changing multi_turn demo to use the new loading function with merge
support.

* move printing trainable params

* move training model creation back to cpu

Closes #15, #13
  • Loading branch information
linziyi96 authored Aug 4, 2023
1 parent 89425cd commit ab84c96
Show file tree
Hide file tree
Showing 11 changed files with 347 additions and 58 deletions.
61 changes: 38 additions & 23 deletions accessory/demos/multi_turn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import gradio as gr

from util.misc import setup_for_distributed, load_pretrained
from util.misc import setup_for_distributed
from util.tensor_parallel import load_tensor_parallel_model
from util.tensor_type import default_tensor_type
from model.meta import MetaModel
from data.conversation.lib import conv_templates, SeparatorStyle

Expand Down Expand Up @@ -50,14 +52,18 @@ def model_worker(
# set the print behavior.
setup_for_distributed(rank == 0)

torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = MetaModel(
args.llama_type, args.llama_config, args.tokenizer_path,
with_visual=False, max_seq_len=args.model_max_seq_len,
)
torch.set_default_tensor_type(torch.FloatTensor)
target_dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
}[args.dtype]
with default_tensor_type(dtype=target_dtype, device="cuda"):
model = MetaModel(
args.llama_type, args.llama_config, args.tokenizer_path,
with_visual=False, max_seq_len=args.model_max_seq_len,
)
model.eval()
print(f"Loading pretrained weights from {args.pretrained_path}")
load_pretrained(args.pretrained_path, args.pretrained_type, model)
load_tensor_parallel_model(model, args.pretrained_path, args.pretrained_type)
print(f"Model = {str(model)}")

barrier.wait()
Expand All @@ -67,22 +73,29 @@ def model_worker(
for user, bot in chatbot:
conv.append_message(conv.roles[0], user)
conv.append_message(conv.roles[1], bot)
# print(conv.get_prompt())

with torch.cuda.amp.autocast(dtype=torch.bfloat16):
for stream_response in model.stream_generate(
conv.get_prompt(), None,
max_gen_len, temperature, top_p
):
end_pos = stream_response['text'].find(conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2)
if end_pos != -1:
stream_response['text'] = stream_response['text'][:end_pos].rstrip()+"\n"
stream_response['end_of_content'] = True
if response_queue is not None:
response_queue.put(stream_response)

if stream_response['end_of_content']:
break
for stream_response in model.stream_generate(
conv.get_prompt(), None,
max_gen_len, temperature, top_p
):
conv_sep = conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2
end_pos = stream_response["text"].find(conv_sep)
if end_pos != -1:
stream_response["text"] = stream_response['text'][:end_pos].rstrip() + "\n"
stream_response["end_of_content"] = True

# keep a few characters if not end_of_content to avoid sending part of conv_sep
# before all of it is generated.
if not stream_response["end_of_content"]:
if len(stream_response["text"]) < len(conv_sep):
continue
stream_response["text"] = stream_response["text"][:-len(conv_sep)]

if response_queue is not None:
response_queue.put(stream_response)

if stream_response["end_of_content"]:
break


def gradio_worker(
Expand Down Expand Up @@ -178,6 +191,8 @@ def undo(chatbot):
help="A port used by the PyTorch distributed module to initialize.")
parser.add_argument("--master_addr", type=str, default="127.0.0.1",
help="An address used by the PyTorch distributed module to initialize.")
parser.add_argument("--dtype", type=str, choices=["fp16", "bf16"], default="bf16",
help="The dtype used for model weights and inference.")
args = parser.parse_args()

# check and setup gpu_ids to use
Expand Down
18 changes: 11 additions & 7 deletions accessory/main_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32
from model.meta import MetaModel
from engine_finetune import train_one_epoch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -150,8 +151,16 @@ def main(args):
dp_group = fs_init.get_data_parallel_group()

# define the model
model = MetaModel(args.llama_type, args.llama_config,
args.tokenizer_path, with_visual=not args.no_visual)
mixed_precision_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"tf32": torch.float32,
}[args.precision]
with default_tensor_type(dtype=mixed_precision_dtype, device="cpu"):
model = MetaModel(args.llama_type, args.llama_config,
args.tokenizer_path, with_visual=not args.no_visual)
promote_trainable_params_to_fp32(model)
misc.print_trainable_params(model)
print(f"load pretrained from {args.pretrained_path}")
misc.load_pretrained(args.pretrained_path, args.pretrained_type, model)
print("Unwrapped Model = %s" % str(model))
Expand All @@ -160,11 +169,6 @@ def main(args):
if args.resume:
misc.resume_stage1(args, model_without_FSDP=model)

mixed_precision_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"tf32": torch.float32,
}[args.precision]
TransformerBlock = type(model.llma.layers[0])
# ignored_named_parameters = {name: param for name, param in model.named_parameters() if not param.requires_grad}
# print(ignored_named_parameters.keys())
Expand Down
19 changes: 12 additions & 7 deletions accessory/main_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32
from model.meta import MetaModel
from engine_pretrain import train_one_epoch, val_one_epoch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -147,8 +148,16 @@ def main(args):
dp_group = fs_init.get_data_parallel_group()

# define the model
model = MetaModel(args.llama_type, args.llama_config,
args.tokenizer_path, with_visual=False)
mixed_precision_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"tf32": torch.float32,
}[args.precision]
with default_tensor_type(dtype=mixed_precision_dtype, device="cpu"):
model = MetaModel(args.llama_type, args.llama_config,
args.tokenizer_path, with_visual=False)
promote_trainable_params_to_fp32(model)
misc.print_trainable_params(model)
if args.pretrained_path:
print(f"load pretrained from {args.pretrained_path}")
misc.load_pretrained(args.pretrained_path, args.pretrained_type, model)
Expand All @@ -158,11 +167,7 @@ def main(args):
if args.resume:
misc.resume_stage1(args, model_without_FSDP=model)

mixed_precision_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"tf32": torch.float32,
}[args.precision]

TransformerBlock = type(model.llma.layers[0])

model = FSDP(
Expand Down
14 changes: 7 additions & 7 deletions accessory/model/LLM/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from apex.normalization import FusedRMSNorm as RMSNorm
import open_clip

from util.tensor_type import default_tensor_type
import configs.global_configs
if configs.global_configs.USE_FLASH_ATTENTION:
from flash_attn import flash_attn_func
Expand Down Expand Up @@ -308,9 +309,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.cache_image_words = 0 # for inference
if with_visual:
print("build llama model with clip")
torch.set_default_tensor_type(torch.cuda.HalfTensor)
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
torch.set_default_tensor_type(torch.FloatTensor)
with default_tensor_type(dtype=torch.half):
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
for name, param in self.clip.named_parameters():
param.requires_grad = False
in_dim = self.clip.visual.proj.shape[1]
Expand All @@ -334,9 +334,7 @@ def get_trainable_params(self):
def set_default_trainability(self):
for key, value in self.named_parameters():
value.requires_grad = False
value.data = value.data.half()
for key, value in self.get_trainable_params().items():
value.data = value.data.float()
value.requires_grad = True


Expand Down Expand Up @@ -366,8 +364,10 @@ def clip_encode_image(self, x):


def encode_image(self, image):
# return self.patch_embed(image)
image_tokens = self.clip_encode_image(image)
with torch.cuda.amp.autocast(enabled=False):
image = image.half()
image_tokens = self.clip_encode_image(image)
image = image.to(self.clip_proj.weight.dtype)
image_tokens = self.clip_proj_norm(self.clip_proj(image_tokens))
return image_tokens

Expand Down
8 changes: 3 additions & 5 deletions accessory/model/LLM/llama_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import configs.global_configs
if configs.global_configs.USE_FLASH_ATTENTION:
from flash_attn import flash_attn_func
from util.tensor_type import default_tensor_type

default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))

Expand Down Expand Up @@ -349,9 +350,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.image_words = 0
if with_visual:
print("build llama model with clip")
torch.set_default_tensor_type(torch.cuda.HalfTensor)
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
torch.set_default_tensor_type(torch.FloatTensor)
with default_tensor_type(dtype=torch.half):
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
for name, param in self.clip.named_parameters():
param.requires_grad = False
in_dim = self.clip.visual.proj.shape[1]
Expand Down Expand Up @@ -401,9 +401,7 @@ def get_trainable_params(self):
def set_default_trainability(self):
for key, value in self.named_parameters():
value.requires_grad = False
value.data = value.data.half()
for key, value in self.get_trainable_params().items():
value.data = value.data.float()
value.requires_grad = True


Expand Down
8 changes: 3 additions & 5 deletions accessory/model/LLM/llama_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ColumnParallelLinear
)
from ..peft import LoraColumnParallelLinear, LoraRowParallelLinear
from util.tensor_type import default_tensor_type

from apex.normalization import FusedRMSNorm as RMSNorm
import open_clip
Expand Down Expand Up @@ -323,9 +324,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
self.cache_image_words = 0 # for inference
if with_visual:
print("build llama model with clip")
torch.set_default_tensor_type(torch.cuda.HalfTensor)
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
torch.set_default_tensor_type(torch.FloatTensor)
with default_tensor_type(dtype=torch.half):
self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
for name, param in self.clip.named_parameters():
param.requires_grad = False
in_dim = self.clip.visual.proj.shape[1]
Expand All @@ -351,9 +351,7 @@ def get_trainable_params(self):
def set_default_trainability(self):
for key, value in self.named_parameters():
value.requires_grad = False
value.data = value.data.half()
for key, value in self.get_trainable_params().items():
value.data = value.data.float()
value.requires_grad = True


Expand Down
2 changes: 0 additions & 2 deletions accessory/model/LLM/llama_qformerv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,7 @@ def get_trainable_params(self):
def set_default_trainability(self):
for key, value in self.named_parameters():
value.requires_grad = False
value.data = value.data.half()
for key, value in self.get_trainable_params().items():
value.data = value.data.float()
value.requires_grad = True


Expand Down
1 change: 0 additions & 1 deletion accessory/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(
for name, param in self.named_parameters():
is_model_parallel = getattr(param, "is_model_parallel", False)
if param.requires_grad:
print(f"Trainable param: {name}, local_size: {param.shape}, model_parallel: {is_model_parallel}, dtype: {param.dtype}")
if is_model_parallel:
param_count_all += param.numel() * fs_init.get_model_parallel_world_size()
else:
Expand Down
13 changes: 12 additions & 1 deletion accessory/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,13 @@ def _save_model():
model_trainable_params = model.get_trainable_params()
model_trainable_params = ['.'.join([_ for _ in key.split('.') if not _.startswith('_')])
for key in model_trainable_params.keys()]
save_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"tf32": torch.float,
}[args.precision]
consolidated_model_state_dict = {
"model": {key: val.half() for key, val in model.state_dict().items() if key in model_trainable_params},
"model": {key: val.to(save_dtype) for key, val in model.state_dict().items() if key in model_trainable_params},
}
save_path = os.path.join(
save_dir,
Expand Down Expand Up @@ -608,3 +613,9 @@ def mark_mp_params(model: torch.nn.Module):
if isinstance(m, ParallelEmbedding):
m.weight.is_model_parallel = True


def print_trainable_params(model: torch.nn.Module) -> None:
for name, param in model.named_parameters():
is_model_parallel = getattr(param, "is_model_parallel", False)
print(f"Trainable param: {name}, local_size: {param.shape}, model_parallel: {is_model_parallel}, dtype: {param.dtype}")

Loading

0 comments on commit ab84c96

Please sign in to comment.