diff --git a/accessory/demos/multi_turn.py b/accessory/demos/multi_turn.py index a818bb1c..f4ff5a6f 100644 --- a/accessory/demos/multi_turn.py +++ b/accessory/demos/multi_turn.py @@ -14,7 +14,9 @@ import gradio as gr -from util.misc import setup_for_distributed, load_pretrained +from util.misc import setup_for_distributed +from util.tensor_parallel import load_tensor_parallel_model +from util.tensor_type import default_tensor_type from model.meta import MetaModel from data.conversation.lib import conv_templates, SeparatorStyle @@ -50,14 +52,18 @@ def model_worker( # set the print behavior. setup_for_distributed(rank == 0) - torch.set_default_tensor_type(torch.cuda.HalfTensor) - model = MetaModel( - args.llama_type, args.llama_config, args.tokenizer_path, - with_visual=False, max_seq_len=args.model_max_seq_len, - ) - torch.set_default_tensor_type(torch.FloatTensor) + target_dtype = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + }[args.dtype] + with default_tensor_type(dtype=target_dtype, device="cuda"): + model = MetaModel( + args.llama_type, args.llama_config, args.tokenizer_path, + with_visual=False, max_seq_len=args.model_max_seq_len, + ) + model.eval() print(f"Loading pretrained weights from {args.pretrained_path}") - load_pretrained(args.pretrained_path, args.pretrained_type, model) + load_tensor_parallel_model(model, args.pretrained_path, args.pretrained_type) print(f"Model = {str(model)}") barrier.wait() @@ -67,22 +73,29 @@ def model_worker( for user, bot in chatbot: conv.append_message(conv.roles[0], user) conv.append_message(conv.roles[1], bot) - # print(conv.get_prompt()) - with torch.cuda.amp.autocast(dtype=torch.bfloat16): - for stream_response in model.stream_generate( - conv.get_prompt(), None, - max_gen_len, temperature, top_p - ): - end_pos = stream_response['text'].find(conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2) - if end_pos != -1: - stream_response['text'] = stream_response['text'][:end_pos].rstrip()+"\n" - stream_response['end_of_content'] = True - if response_queue is not None: - response_queue.put(stream_response) - - if stream_response['end_of_content']: - break + for stream_response in model.stream_generate( + conv.get_prompt(), None, + max_gen_len, temperature, top_p + ): + conv_sep = conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2 + end_pos = stream_response["text"].find(conv_sep) + if end_pos != -1: + stream_response["text"] = stream_response['text'][:end_pos].rstrip() + "\n" + stream_response["end_of_content"] = True + + # keep a few characters if not end_of_content to avoid sending part of conv_sep + # before all of it is generated. 
+ if not stream_response["end_of_content"]: + if len(stream_response["text"]) < len(conv_sep): + continue + stream_response["text"] = stream_response["text"][:-len(conv_sep)] + + if response_queue is not None: + response_queue.put(stream_response) + + if stream_response["end_of_content"]: + break def gradio_worker( @@ -178,6 +191,8 @@ def undo(chatbot): help="A port used by the PyTorch distributed module to initialize.") parser.add_argument("--master_addr", type=str, default="127.0.0.1", help="An address used by the PyTorch distributed module to initialize.") + parser.add_argument("--dtype", type=str, choices=["fp16", "bf16"], default="bf16", + help="The dtype used for model weights and inference.") args = parser.parse_args() # check and setup gpu_ids to use diff --git a/accessory/main_finetune.py b/accessory/main_finetune.py index 5638801e..5a19d673 100644 --- a/accessory/main_finetune.py +++ b/accessory/main_finetune.py @@ -32,6 +32,7 @@ import util.misc as misc from util.misc import NativeScalerWithGradNormCount as NativeScaler +from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32 from model.meta import MetaModel from engine_finetune import train_one_epoch from torch.utils.data import Dataset @@ -150,8 +151,16 @@ def main(args): dp_group = fs_init.get_data_parallel_group() # define the model - model = MetaModel(args.llama_type, args.llama_config, - args.tokenizer_path, with_visual=not args.no_visual) + mixed_precision_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "tf32": torch.float32, + }[args.precision] + with default_tensor_type(dtype=mixed_precision_dtype, device="cpu"): + model = MetaModel(args.llama_type, args.llama_config, + args.tokenizer_path, with_visual=not args.no_visual) + promote_trainable_params_to_fp32(model) + misc.print_trainable_params(model) print(f"load pretrained from {args.pretrained_path}") misc.load_pretrained(args.pretrained_path, args.pretrained_type, model) print("Unwrapped Model = %s" % str(model)) @@ -160,11 +169,6 @@ def main(args): if args.resume: misc.resume_stage1(args, model_without_FSDP=model) - mixed_precision_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "tf32": torch.float32, - }[args.precision] TransformerBlock = type(model.llma.layers[0]) # ignored_named_parameters = {name: param for name, param in model.named_parameters() if not param.requires_grad} # print(ignored_named_parameters.keys()) diff --git a/accessory/main_pretrain.py b/accessory/main_pretrain.py index bc5ede24..5359ee59 100644 --- a/accessory/main_pretrain.py +++ b/accessory/main_pretrain.py @@ -32,6 +32,7 @@ import util.misc as misc from util.misc import NativeScalerWithGradNormCount as NativeScaler +from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32 from model.meta import MetaModel from engine_pretrain import train_one_epoch, val_one_epoch from torch.utils.data import Dataset @@ -147,8 +148,16 @@ def main(args): dp_group = fs_init.get_data_parallel_group() # define the model - model = MetaModel(args.llama_type, args.llama_config, - args.tokenizer_path, with_visual=False) + mixed_precision_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "tf32": torch.float32, + }[args.precision] + with default_tensor_type(dtype=mixed_precision_dtype, device="cpu"): + model = MetaModel(args.llama_type, args.llama_config, + args.tokenizer_path, with_visual=False) + promote_trainable_params_to_fp32(model) + misc.print_trainable_params(model) if args.pretrained_path: print(f"load pretrained 
from {args.pretrained_path}") misc.load_pretrained(args.pretrained_path, args.pretrained_type, model) @@ -158,11 +167,7 @@ def main(args): if args.resume: misc.resume_stage1(args, model_without_FSDP=model) - mixed_precision_dtype = { - "fp16": torch.float16, - "bf16": torch.bfloat16, - "tf32": torch.float32, - }[args.precision] + TransformerBlock = type(model.llma.layers[0]) model = FSDP( diff --git a/accessory/model/LLM/llama.py b/accessory/model/LLM/llama.py index d763cf8a..4f1d8877 100644 --- a/accessory/model/LLM/llama.py +++ b/accessory/model/LLM/llama.py @@ -20,6 +20,7 @@ from apex.normalization import FusedRMSNorm as RMSNorm import open_clip +from util.tensor_type import default_tensor_type import configs.global_configs if configs.global_configs.USE_FLASH_ATTENTION: from flash_attn import flash_attn_func @@ -308,9 +309,8 @@ def __init__(self, params: ModelArgs, with_visual=False): self.cache_image_words = 0 # for inference if with_visual: print("build llama model with clip") - torch.set_default_tensor_type(torch.cuda.HalfTensor) - self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') - torch.set_default_tensor_type(torch.FloatTensor) + with default_tensor_type(dtype=torch.half): + self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') for name, param in self.clip.named_parameters(): param.requires_grad = False in_dim = self.clip.visual.proj.shape[1] @@ -334,9 +334,7 @@ def get_trainable_params(self): def set_default_trainability(self): for key, value in self.named_parameters(): value.requires_grad = False - value.data = value.data.half() for key, value in self.get_trainable_params().items(): - value.data = value.data.float() value.requires_grad = True @@ -366,8 +364,10 @@ def clip_encode_image(self, x): def encode_image(self, image): - # return self.patch_embed(image) - image_tokens = self.clip_encode_image(image) + with torch.cuda.amp.autocast(enabled=False): + image = image.half() + image_tokens = self.clip_encode_image(image) + image = image.to(self.clip_proj.weight.dtype) image_tokens = self.clip_proj_norm(self.clip_proj(image_tokens)) return image_tokens diff --git a/accessory/model/LLM/llama_adapter.py b/accessory/model/LLM/llama_adapter.py index 289a15c8..ec59d441 100644 --- a/accessory/model/LLM/llama_adapter.py +++ b/accessory/model/LLM/llama_adapter.py @@ -24,6 +24,7 @@ import configs.global_configs if configs.global_configs.USE_FLASH_ATTENTION: from flash_attn import flash_attn_func +from util.tensor_type import default_tensor_type default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5)) @@ -349,9 +350,8 @@ def __init__(self, params: ModelArgs, with_visual=False): self.image_words = 0 if with_visual: print("build llama model with clip") - torch.set_default_tensor_type(torch.cuda.HalfTensor) - self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') - torch.set_default_tensor_type(torch.FloatTensor) + with default_tensor_type(dtype=torch.half): + self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') for name, param in self.clip.named_parameters(): param.requires_grad = False in_dim = self.clip.visual.proj.shape[1] @@ -401,9 +401,7 @@ def get_trainable_params(self): def set_default_trainability(self): for key, value in self.named_parameters(): value.requires_grad = False - value.data = value.data.half() for key, value in self.get_trainable_params().items(): - value.data = value.data.float() value.requires_grad = 
True diff --git a/accessory/model/LLM/llama_peft.py b/accessory/model/LLM/llama_peft.py index 4a4199da..bec60658 100644 --- a/accessory/model/LLM/llama_peft.py +++ b/accessory/model/LLM/llama_peft.py @@ -17,6 +17,7 @@ ColumnParallelLinear ) from ..peft import LoraColumnParallelLinear, LoraRowParallelLinear +from util.tensor_type import default_tensor_type from apex.normalization import FusedRMSNorm as RMSNorm import open_clip @@ -323,9 +324,8 @@ def __init__(self, params: ModelArgs, with_visual=False): self.cache_image_words = 0 # for inference if with_visual: print("build llama model with clip") - torch.set_default_tensor_type(torch.cuda.HalfTensor) - self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') - torch.set_default_tensor_type(torch.FloatTensor) + with default_tensor_type(dtype=torch.half): + self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai') for name, param in self.clip.named_parameters(): param.requires_grad = False in_dim = self.clip.visual.proj.shape[1] @@ -351,9 +351,7 @@ def get_trainable_params(self): def set_default_trainability(self): for key, value in self.named_parameters(): value.requires_grad = False - value.data = value.data.half() for key, value in self.get_trainable_params().items(): - value.data = value.data.float() value.requires_grad = True diff --git a/accessory/model/LLM/llama_qformerv2.py b/accessory/model/LLM/llama_qformerv2.py index 7970bbd6..bc3124c9 100644 --- a/accessory/model/LLM/llama_qformerv2.py +++ b/accessory/model/LLM/llama_qformerv2.py @@ -337,9 +337,7 @@ def get_trainable_params(self): def set_default_trainability(self): for key, value in self.named_parameters(): value.requires_grad = False - value.data = value.data.half() for key, value in self.get_trainable_params().items(): - value.data = value.data.float() value.requires_grad = True diff --git a/accessory/model/meta.py b/accessory/model/meta.py index a5e5a478..093d608f 100644 --- a/accessory/model/meta.py +++ b/accessory/model/meta.py @@ -48,7 +48,6 @@ def __init__( for name, param in self.named_parameters(): is_model_parallel = getattr(param, "is_model_parallel", False) if param.requires_grad: - print(f"Trainable param: {name}, local_size: {param.shape}, model_parallel: {is_model_parallel}, dtype: {param.dtype}") if is_model_parallel: param_count_all += param.numel() * fs_init.get_model_parallel_world_size() else: diff --git a/accessory/util/misc.py b/accessory/util/misc.py index 5403eca0..9381393d 100644 --- a/accessory/util/misc.py +++ b/accessory/util/misc.py @@ -351,8 +351,13 @@ def _save_model(): model_trainable_params = model.get_trainable_params() model_trainable_params = ['.'.join([_ for _ in key.split('.') if not _.startswith('_')]) for key in model_trainable_params.keys()] + save_dtype = { + "fp16": torch.float16, + "bf16": torch.bfloat16, + "tf32": torch.float, + }[args.precision] consolidated_model_state_dict = { - "model": {key: val.half() for key, val in model.state_dict().items() if key in model_trainable_params}, + "model": {key: val.to(save_dtype) for key, val in model.state_dict().items() if key in model_trainable_params}, } save_path = os.path.join( save_dir, @@ -608,3 +613,9 @@ def mark_mp_params(model: torch.nn.Module): if isinstance(m, ParallelEmbedding): m.weight.is_model_parallel = True + +def print_trainable_params(model: torch.nn.Module) -> None: + for name, param in model.named_parameters(): + is_model_parallel = getattr(param, "is_model_parallel", False) + print(f"Trainable param: {name}, 
local_size: {param.shape}, model_parallel: {is_model_parallel}, dtype: {param.dtype}") + diff --git a/accessory/util/tensor_parallel.py b/accessory/util/tensor_parallel.py new file mode 100644 index 00000000..4a990d60 --- /dev/null +++ b/accessory/util/tensor_parallel.py @@ -0,0 +1,195 @@ +from collections import OrderedDict +import os +import re +from typing import Dict, List, NamedTuple, Tuple, Type + +import torch +import torch.nn as nn + +import fairscale.nn.model_parallel.initialize as fs_init +from fairscale.nn.model_parallel.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ParallelEmbedding, +) + +r"""_MODEL_PARALLEL_MODULES defines a list of module classes that contains tensor-parallel +parameters which may need special handling. + +Each item is a pair whose first item is the module class, and the second item is a dictionary +defining along which dim is each of its weights splitted. + +_MODEL_PARALLEL_MODULES is defined as a ``list`` instead of a ``dict`` for well-defined matching +priority: The matching process is expected to be in the order defined in the list and exit on the +first match as returned by ``isinstance``. The design is to handle module sub-classing: Any subclass +of the defined classes can also be matched, and any different handling of the subclass should be +defined BEFORE the item of the parent class. + +To correctly save and load the checkpoints we expect each newly involved tensor parallel layer +to be registered in this list. +""" +_MODEL_PARALLEL_MODULES: List[Tuple[Type[nn.Module], Dict[str, int]]] = [ + (ColumnParallelLinear, {"weight": 0, "bias": 0}), + (RowParallelLinear, {"weight": 1, "bias": -1}), + (ParallelEmbedding, {"weight": 1}), +] + +def _tensor_list_max_diff(tensors: List[torch.Tensor]) -> float: + for tensor in tensors[1:]: + assert tensor.dtype is tensors[0].dtype and tensor.size() == tensors[0].size() + + if tensors[0].is_complex(): + max_diff = 0. + for i in range(len(tensors)): + for j in range(i + 1, len(tensors)): + max_diff = max(max_diff, (tensors[i] - tensors[j]).abs().max().item()) + return max_diff + + if not tensors[0].is_floating_point(): + tensors = [tensor.float() for tensor in tensors] + max_tensor, min_tensor = tensors[0].clone(), tensors[0].clone() + for tensor in tensors[1:]: + max_tensor = torch.maximum(tensor, max_tensor) + min_tensor = torch.minimum(tensor, min_tensor) + return (max_tensor - min_tensor).max().item() + + +def _load_checkpoint_and_merge_ranks( + ckpt_files: List[str], weight_parallel_dim: Dict[str, int], verbose: bool = False, +) -> OrderedDict[str, torch.Tensor]: + mp_rank = fs_init.get_model_parallel_rank() + mp_world_size = fs_init.get_model_parallel_world_size() + ckpt_world_size = len(ckpt_files) + + assert ckpt_world_size % mp_world_size == 0 + local_num_shards = ckpt_world_size // mp_world_size + local_shard_st = local_num_shards * mp_rank + local_shard_ed = local_num_shards * (mp_rank + 1) + ckpt_shards = [] + merged_ckpt = OrderedDict() + for shard_id in range(local_shard_st, local_shard_ed): + shard = torch.load(ckpt_files[shard_id], map_location="cpu") + if "model" in shard and isinstance(shard["model"], dict): + shard = shard["model"] + ckpt_shards.append(shard) + + for key in list(ckpt_shards[0].keys()): + param_shards = [shard[key] for shard in ckpt_shards] + if key not in weight_parallel_dim: # non tensor parallel parameter + max_diff = _tensor_list_max_diff(param_shards) + if max_diff > 0.: + print( + "WARNING! 
Found unequal replicas of non-tensor-parallel params: " + f"name={key}, ranks={','.join(str(x) for x in range(local_shard_st, local_shard_ed))}, " + f"max_diff={max_diff}.", + force=True, + ) + merged_ckpt[key] = param_shards[0] + else: + merged_ckpt[key] = torch.cat(param_shards, dim=weight_parallel_dim[key]) + + # delete the original weights to avoid 2x memory usage. + for shard in ckpt_shards: + del shard[key] + + return merged_ckpt + + +def _load_checkpoint_and_split_rank( + ckpt_files: List[str], weight_parallel_dim: Dict[str, int], verbose: bool = False, +) -> OrderedDict[str, torch.Tensor]: + raise NotImplementedError() + + +def _load_checkpoint_and_redistribute_general( + ckpt_files: List[str], weight_parallel_dim: Dict[str, int], verbose: bool = False, +) -> OrderedDict[str, torch.Tensor]: + raise NotImplementedError() + + +def load_tensor_parallel_model( + model: nn.Module, path: str, format: str, verbose: bool = False +) -> Tuple[List[str], List[str]]: + r"""This function loads tensor parallel checkpoints to a model. It handles different formats + (e.g., saved by different training frameworks or released by different organizations) and potentially + a change of tensor parallel size (e.g., reducing tensor parallel size when running on fewer GPUs + each with larger memory). + + Args: + model (nn.Module): The model to load the checkpoint into. + path (str): A path containing checkpoint files. + format (str): Format of the checkpoing files. Supported formats: ``consolidated`` (saved by our + framework) and ``meta_ori`` (original checkpoints released in Meta's LLaMA repo). + verbose (bool): Print verbose information about the loading process for debug purposes. + Default=``False``. + """ + + def print_if_verbose(*args, **kwargs): + if verbose: + print(*args, **kwargs) + + weight_parallel_dim = {} + for name, module in model.named_modules(): + for class_, dict_ in _MODEL_PARALLEL_MODULES: + if isinstance(module, class_): + for leaf_name, dim in dict_.items(): + full_name = name + "." + leaf_name if name else leaf_name + if dim >= 0: + weight_parallel_dim[full_name] = dim + break + + mp_world_size = fs_init.get_model_parallel_world_size() + + if format in ["meta_ori", "consolidated"]: + # meta_ori and consolidated are essentially the same format: Both store weights + # of each model parallel rank in a separate file. The minor differences are: + # 1. In "meta_ori" format, filenames contain only model_parallel_rank but in + # "consolidated" format, filenames also contain model_parallel_world_size to + # make a missing part of the checkpoints instantly noticeable. + # 2. In "consolidated" format, state keys additionally contain the "llma." prefix. + + # Integrity check and checkpoint mp_world_size calculation if needed. + if format == "meta_ori": + pattern = re.compile("^consolidated.(\d{2}).pth$") + else: + pattern = re.compile("^consolidated.(\d{2})-of-(\d{2}).model.pth$") + ckpt_fns = [fn for fn in os.listdir(path) if pattern.match(fn)] + ckpt_mp_world_size = len(ckpt_fns) + assert ckpt_mp_world_size > 0, ( + f"\"{path}\" is not a valid {format} format checkpoint path: " + "No file with valid name is found in the path." + ) + ckpt_files = [] + for i in range(ckpt_mp_world_size): + if format == "meta_ori": + fn = f"consolidated.{i:02d}.pth" + else: + fn = f"consolidated.{i:02d}-of-{ckpt_mp_world_size:02d}.model.pth" + full_path = os.path.join(path, fn) + assert os.path.isfile(full_path), f"\"{full_path}\" is not a file." 
+ ckpt_files.append(full_path) + + # Dispatch to different implementations for better performance: Shorten the start-up + # time as much as possible because we strive for better user experience! + if ckpt_mp_world_size % mp_world_size == 0: + local_state_dict = _load_checkpoint_and_merge_ranks( + ckpt_files, weight_parallel_dim, verbose + ) + elif mp_world_size % ckpt_mp_world_size == 0: + local_state_dict = _load_checkpoint_and_split_rank( + ckpt_files, weight_parallel_dim, verbose + ) + else: + local_state_dict = _load_checkpoint_and_redistribute_general( + ckpt_files, weight_parallel_dim, verbose + ) + + if format == "meta_ori": + local_state_dict = OrderedDict( + ("llma." + key, value) for key, value in local_state_dict.items() + ) + + return model.load_state_dict(local_state_dict, strict=False) + + else: + raise NotImplementedError(f"Checkpoint format {format} is unknown.") diff --git a/accessory/util/tensor_type.py b/accessory/util/tensor_type.py new file mode 100644 index 00000000..e9ddb9fe --- /dev/null +++ b/accessory/util/tensor_type.py @@ -0,0 +1,66 @@ +from types import TracebackType +from typing import Any, Optional +import torch +import torch.nn as nn + + +class default_tensor_type: + _tensor_type_stack = [(torch.float, "cpu")] + + def __init__( + self, + dtype: Optional[torch.dtype] = None, + device: Optional[str] = None, + ) -> None: + # Only limited combinations are supported. + assert device is None or device in ["cpu", "cuda"] + assert dtype is None or dtype in [torch.float, torch.bfloat16, torch.half] + self.dtype, self.device = dtype, device + + def __enter__(self) -> None: + dtype, device = self.dtype, self.device + if dtype is None: + dtype = default_tensor_type._tensor_type_stack[-1][0] + if device is None: + device = default_tensor_type._tensor_type_stack[-1][1] + default_tensor_type._tensor_type_stack.append((dtype, device)) + + # We use all 3 calls since the new apis (set_default_device, set_default_dtype) + # seems to be ineffective sometimes (e.g., set_default_device is ineffective to + # torch.Tensor calls). + torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device)) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + default_tensor_type._tensor_type_stack.pop() + dtype, device = default_tensor_type._tensor_type_stack[-1] + + torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device)) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + @staticmethod + def get_tensor_type(dtype: torch.dtype, device: str) -> Any: + return { + (torch.float, "cpu"): torch.FloatTensor, + (torch.bfloat16, "cpu"): torch.BFloat16Tensor, + (torch.half, "cpu"): torch.HalfTensor, + (torch.float, "cuda"): torch.cuda.FloatTensor, + (torch.bfloat16, "cuda"): torch.cuda.BFloat16Tensor, + (torch.half, "cuda"): torch.cuda.HalfTensor, + }[(dtype, device)] + + +def promote_trainable_params_to_fp32(model: nn.Module) -> None: + for param in model.parameters(): + if param.requires_grad: + if param.is_floating_point() and torch.finfo(param.dtype).bits < 32: + param.data = param.data.float() + if param.is_complex() and torch.finfo(param.dtype).bits < 32: + param.data = param.data.to(torch.complex64) \ No newline at end of file
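
Usage note on the streaming change in demos/multi_turn.py (illustrative sketch, not part of the patch; the separator string below is a placeholder): until end_of_content is detected, the worker withholds the trailing len(conv_sep) characters of the streamed text so that a partially generated stop separator never reaches the client, and it emits nothing at all while the text is still shorter than the separator.

# Sketch of the hold-back rule applied to one streamed chunk (placeholder values).
conv_sep = "###"                        # assumed separator; the real one comes from conv.sep / conv.sep2
partial = "The answer is 42.##"         # trailing "##" could be the start of a not-yet-complete "###"
if len(partial) >= len(conv_sep):
    visible = partial[:-len(conv_sep)]  # withhold the last len(conv_sep) characters
    print(visible)                      # -> "The answer is 42"
# Once the full separator appears, the text is truncated at it and end_of_content is set.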
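
Usage note on util/tensor_type.py (illustrative sketch, not part of the patch): default_tensor_type keeps a stack of (dtype, device) defaults, so nested scopes inherit whichever field is left unspecified and the previous default is restored on exit; promote_trainable_params_to_fp32 then casts only parameters with requires_grad=True back to full precision, as done before FSDP wrapping in main_finetune.py and main_pretrain.py above.

# Sketch assuming the accessory package is on sys.path and a CUDA device is available.
import torch
from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32

with default_tensor_type(dtype=torch.bfloat16, device="cuda"):
    a = torch.empty(4)                  # bf16 on cuda
    with default_tensor_type(device="cpu"):
        b = torch.empty(4)              # dtype inherited from the outer scope (bf16), device overridden to cpu
    c = torch.empty(4)                  # bf16 on cuda again after the inner scope exits
d = torch.empty(4)                      # fp32 on cpu: the initial default is restored

# After constructing a model under a low-precision default, only the trainable
# parameters are promoted to fp32 (frozen ones keep the construction dtype), e.g.:
#   promote_trainable_params_to_fp32(model)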
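
Usage note on util/tensor_parallel.py (illustrative sketch, not part of the patch; paths and sizes are placeholders): load_tensor_parallel_model expects one checkpoint file per saved model-parallel rank, named consolidated.XX.pth in the meta_ori format and consolidated.XX-of-YY.model.pth in the consolidated format. Of the three redistribution strategies, only the merge path (checkpoint model-parallel size an integer multiple of the runtime model-parallel size) is implemented in this patch; the split and general paths still raise NotImplementedError.

# Sketch: torch.distributed and fairscale model parallelism must already be initialized,
# and `model` is assumed to be an already-constructed MetaModel (see the demo hunk above).
from util.tensor_parallel import load_tensor_parallel_model

# Meta's original LLaMA shards: consolidated.00.pth, consolidated.01.pth, ...
missing, unexpected = load_tensor_parallel_model(model, "/path/to/llama_weights", "meta_ori")

# Shards saved by this framework: consolidated.00-of-02.model.pth, consolidated.01-of-02.model.pth, ...
missing, unexpected = load_tensor_parallel_model(model, "/path/to/finetuned_ckpt", "consolidated")

# The return value is that of model.load_state_dict(..., strict=False):
# the keys missing from the checkpoint and the unexpected keys found in it.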